Move SpeculativeConfig from config/__init__.py to config/speculative.py (#24904)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-09-16 12:51:35 +01:00 committed by GitHub
parent 7ea5c73ad7
commit 0faf3cc3e8
2 changed files with 556 additions and 523 deletions
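Because `vllm/config/__init__.py` now imports `SpeculativeConfig` from the new module (see the second hunk below), existing call sites that import it from `vllm.config` should keep working. A minimal sanity check, assuming the package layout in this commit:

# Both import paths should resolve to the same class after this commit,
# since config/__init__.py re-imports SpeculativeConfig from the new module.
from vllm.config import SpeculativeConfig as from_package
from vllm.config.speculative import SpeculativeConfig as from_module

assert from_package is from_module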

vllm/config/__init__.py

@@ -24,7 +24,7 @@ from pydantic import (ConfigDict, SkipValidation, field_validator,
model_validator)
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
-from typing_extensions import Self, assert_never, runtime_checkable
+from typing_extensions import assert_never, runtime_checkable
import vllm.envs as envs
from vllm import version
@@ -41,6 +41,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
+from vllm.config.speculative import SpeculativeConfig
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -1846,528 +1847,6 @@ class DeviceConfig:
self.device = torch.device(self.device_type)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp"]
@config
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
# General speculative decoding control
num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present; otherwise, it is required."""
model: Optional[str] = None
"""The name of the draft model, eagle head, or additional weights, if
provided."""
method: Optional[SpeculativeMethod] = None
"""The name of the speculative method to use. If users provide and set the
`model` param, the speculative method type will be detected automatically
if possible, if `model` param is not provided, the method name must be
provided.
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
draft_tensor_parallel_size: Optional[int] = None
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
"""If set to True, token log probabilities are not returned during
speculative decoding. If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams."""
# Draft model configuration
quantization: Optional[me_quant.QuantizationMethods] = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
max_model_len: Optional[int] = None
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: Optional[str] = None
"""The specific model version to use for the draft model. It can be a
branch name, a tag name, or a commit id. If unspecified, will use the
default version."""
code_revision: Optional[str] = None
"""The specific revision to use for the draft model code on Hugging Face
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version."""
# Advanced control
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
prompt_lookup_min: Optional[int] = None
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
speculative_token_tree: Optional[str] = None
"""Specifies the tree structure for speculative token generation.
"""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
target_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decoding."""
disable_log_stats: SkipValidation[bool] = None # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""
# params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
if hf_config.architectures[0] == "MiMoForCausalLM":
hf_config.model_type = "mimo_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
hf_config.model_type = "glm4_moe_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
return hf_config
def __post_init__(self):
# Note: "method" is a new parameter that helps to extend the
# configuration of non-model-based proposers, and the "model" parameter
# will be used to set the draft model, eagle head, or additional weight
# when needed. If users do not specify "method", the speculative method
# will be detected automatically if possible. If the speculative method
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config and \
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe", "qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
raise ValueError("num_speculative_tokens was provided without "
"speculative model.")
# Automatically configure the method for ngram when "model" is used
# instead of "method"
if self.method is None and (self.model is not None
and self.model in ("ngram", "[ngram]")):
self.method = "ngram"
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
# Set default values if not provided
if (self.prompt_lookup_min is None
and self.prompt_lookup_max is None):
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
self.prompt_lookup_min = 5
self.prompt_lookup_max = 5
elif self.prompt_lookup_min is None:
assert self.prompt_lookup_max is not None
self.prompt_lookup_min = self.prompt_lookup_max
elif self.prompt_lookup_max is None:
assert self.prompt_lookup_min is not None
self.prompt_lookup_max = self.prompt_lookup_min
# Validate values
if self.prompt_lookup_min < 1:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
if self.prompt_lookup_max < 1:
raise ValueError(
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must "
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
# TODO: currently we still need to extract vocab_size from the target
# model config; in the future, we may refactor it out and set the
# draft-related config to None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if self.model is not None:
self.draft_model_config = ModelConfig(
model=self.model,
runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_seq_len_to_capture=self.target_model_config.
max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
# Automatically detect the method
if self.method in ('eagle', 'eagle3'):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
# AngelSlim/Qwen3-8B_eagle3
elif "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.transformers_utils.configs import (
SpeculatorsConfig)
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
(EAGLEConfig, SpeculatorsConfig)):
pass
else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config,
method=self.method,
model_type="eagle")
self.draft_model_config.hf_config = eagle_config
if (self.num_speculative_tokens is not None
and hasattr(self.draft_model_config.hf_config,
"num_lookahead_tokens")):
self.draft_model_config.hf_config.num_lookahead_tokens = \
self.num_speculative_tokens
n_predict = getattr(self.draft_model_config.hf_config,
"n_predict", None)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif self.num_speculative_tokens > n_predict and \
self.num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")
if self.speculative_token_tree is None:
# Generate chain of tokens.
self.speculative_token_tree = str([
(i + 1) * (0, )
for i in range(self.num_speculative_tokens)
])
else:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(
self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t)))
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config
)
self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
self.max_model_len,
self.draft_model_config.max_model_len,
self.target_model_config.max_model_len,
))
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config,
self.draft_tensor_parallel_size))
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
draft_max_model_len: int,
target_max_model_len: int,
) -> int:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.
This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
if speculative_max_model_len is not None:
if speculative_max_model_len > draft_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {draft_max_model_len=}")
if speculative_max_model_len > target_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {target_max_model_len=}")
return speculative_max_model_len
return min(
draft_max_model_len,
target_max_model_len,
)
@staticmethod
def _verify_and_get_draft_tp(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig) -> int:
"""
Verifies and adjusts the tensor parallel size for a draft model
specified using speculative_draft_tensor_parallel_size.
"""
# If speculative_draft_tensor_parallel_size is unset then set it
# appropriately else verify that it is set correctly.
if speculative_draft_tensor_parallel_size is None:
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1",
draft_hf_config.model_type)
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size not in (
1, target_parallel_config.tensor_parallel_size):
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} must be "
f"either 1 or the target model's tensor_parallel_size")
return speculative_draft_tensor_parallel_size
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.
disable_custom_all_reduce,
ray_workers_use_nsight=target_parallel_config.
ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
)
return draft_parallel_config
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative model unless the draft model config contains an "
"n_predict parameter.")
if self.num_speculative_tokens <= 0:
raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).")
if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config)
if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported):
raise ValueError(
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}")
return self
@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per
step, in addition to the slots allocated for each known token.
This is equal to the number of speculative tokens, as each speculative
token must be scored.
"""
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
@config
@dataclass
class PoolerConfig:
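The new module sidesteps the circular import between `vllm.config` and the model/quantization packages with the `TYPE_CHECKING`/`LazyLoader` pattern visible at the top of the added file: type-only imports fall back to `Any` at runtime, and the quantization package is wrapped in `LazyLoader` so it is imported on first attribute access. A condensed sketch of the same pattern, mirroring the paths used in the new file:

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Real imports for static analysis only; never executed at runtime.
    import vllm.model_executor.layers.quantization as me_quant
    from vllm.config import ModelConfig
else:
    from vllm.utils import LazyLoader

    ModelConfig = Any  # runtime placeholder; imported locally where needed
    me_quant = LazyLoader("me_quant", globals(),
                          "vllm.model_executor.layers.quantization")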

vllm/config/speculative.py (new file, +554)

@@ -0,0 +1,554 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal, Optional
from pydantic import SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
import vllm.envs as envs
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import LazyLoader
if TYPE_CHECKING:
from transformers import PretrainedConfig
import vllm.model_executor.layers.quantization as me_quant
from vllm.config import ModelConfig
else:
PretrainedConfig = Any
ModelConfig = Any
me_quant = LazyLoader("model_executor", globals(),
"vllm.model_executor.layers.quantization")
logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp"]
@config
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
# General speculative decoding control
num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present; otherwise, it is required."""
model: Optional[str] = None
"""The name of the draft model, eagle head, or additional weights, if
provided."""
method: Optional[SpeculativeMethod] = None
"""The name of the speculative method to use. If users provide and set the
`model` param, the speculative method type will be detected automatically
if possible, if `model` param is not provided, the method name must be
provided.
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
draft_tensor_parallel_size: Optional[int] = None
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
"""If set to True, token log probabilities are not returned during
speculative decoding. If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams."""
# Draft model configuration
quantization: Optional[me_quant.QuantizationMethods] = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
max_model_len: Optional[int] = None
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: Optional[str] = None
"""The specific model version to use for the draft model. It can be a
branch name, a tag name, or a commit id. If unspecified, will use the
default version."""
code_revision: Optional[str] = None
"""The specific revision to use for the draft model code on Hugging Face
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version."""
# Advanced control
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
prompt_lookup_min: Optional[int] = None
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
speculative_token_tree: Optional[str] = None
"""Specifies the tree structure for speculative token generation.
"""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
target_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decoding."""
disable_log_stats: SkipValidation[bool] = None # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""
# params generated in the post-init stage
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: SkipValidation[
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
if hf_config.architectures[0] == "MiMoForCausalLM":
hf_config.model_type = "mimo_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
hf_config.model_type = "glm4_moe_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
return hf_config
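# Example of the override above: a checkpoint whose config declares
# model_type="deepseek_v3" is rewritten to model_type="deepseek_mtp" with
# architectures=["DeepSeekMTPModel"], so that the MTP draft head is
# loaded instead of the full target model.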
def __post_init__(self):
# Note: "method" is a new parameter that helps to extend the
# configuration of non-model-based proposers, and the "model" parameter
# will be used to set the draft model, eagle head, or additional weight
# when needed. If users do not specify "method", the speculative method
# will be detected automatically if possible. If the speculative method
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config and \
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe", "qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
raise ValueError("num_speculative_tokens was provided without "
"speculative model.")
# Automatically configure the method for ngram when "model" is used
# instead of "method"
if self.method is None and (self.model is not None
and self.model in ("ngram", "[ngram]")):
self.method = "ngram"
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
# Set default values if not provided
if (self.prompt_lookup_min is None
and self.prompt_lookup_max is None):
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
self.prompt_lookup_min = 5
self.prompt_lookup_max = 5
elif self.prompt_lookup_min is None:
assert self.prompt_lookup_max is not None
self.prompt_lookup_min = self.prompt_lookup_max
elif self.prompt_lookup_max is None:
assert self.prompt_lookup_min is not None
self.prompt_lookup_max = self.prompt_lookup_min
# Validate values
if self.prompt_lookup_min < 1:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
if self.prompt_lookup_max < 1:
raise ValueError(
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must "
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
# TODO: currently we still need to extract vocab_size from the target
# model config; in the future, we may refactor it out and set the
# draft-related config to None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if self.model is not None:
# TODO: Move this import to the top once `ModelConfig`
# lives in `vllm.config.model`.
from vllm.config import ModelConfig
self.draft_model_config = ModelConfig(
model=self.model,
runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_seq_len_to_capture=self.target_model_config.
max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
# Automatically detect the method
if self.method in ('eagle', 'eagle3'):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
# AngelSlim/Qwen3-8B_eagle3
elif "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.transformers_utils.configs import (
SpeculatorsConfig)
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
(EAGLEConfig, SpeculatorsConfig)):
pass
else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config,
method=self.method,
model_type="eagle")
self.draft_model_config.hf_config = eagle_config
if (self.num_speculative_tokens is not None
and hasattr(self.draft_model_config.hf_config,
"num_lookahead_tokens")):
self.draft_model_config.hf_config.num_lookahead_tokens = \
self.num_speculative_tokens
n_predict = getattr(self.draft_model_config.hf_config,
"n_predict", None)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif self.num_speculative_tokens > n_predict and \
self.num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")
if self.speculative_token_tree is None:
# Generate chain of tokens.
self.speculative_token_tree = str([
(i + 1) * (0, )
for i in range(self.num_speculative_tokens)
])
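# Illustration: num_speculative_tokens=3 yields the chain
# "[(0,), (0, 0), (0, 0, 0)]", i.e. a single path in which each
# speculative token has exactly one child.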
else:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(
self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t)))
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config
)
self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
self.max_model_len,
self.draft_model_config.max_model_len,
self.target_model_config.max_model_len,
))
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config,
self.draft_tensor_parallel_size))
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
draft_max_model_len: int,
target_max_model_len: int,
) -> int:
"""Determine the max sequence len for the draft model. This is usually
the draft_max_model_len, but may be the target_max_model_len if it is
less than the draft_max_model_len, or may be speculative_max_model_len
if it is specified.
This is necessary so that sequences do not exceed the capacity of the
draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can
skip speculation.
"""
if speculative_max_model_len is not None:
if speculative_max_model_len > draft_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {draft_max_model_len=}")
if speculative_max_model_len > target_max_model_len:
raise ValueError(f"{speculative_max_model_len=} cannot be "
f"larger than {target_max_model_len=}")
return speculative_max_model_len
return min(
draft_max_model_len,
target_max_model_len,
)
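# Worked example: draft_max_model_len=4096 and target_max_model_len=8192
# with no override returns min(4096, 8192) == 4096; passing
# speculative_max_model_len=16384 would raise, as it exceeds both limits.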
@staticmethod
def _verify_and_get_draft_tp(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig) -> int:
"""
Verifies and adjusts the tensor parallel size for a draft model
specified using speculative_draft_tensor_parallel_size.
"""
# If speculative_draft_tensor_parallel_size is unset then set it
# appropriately else verify that it is set correctly.
if speculative_draft_tensor_parallel_size is None:
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"%s cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1",
draft_hf_config.model_type)
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size not in (
1, target_parallel_config.tensor_parallel_size):
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} must be "
f"either 1 or the target model's tensor_parallel_size")
return speculative_draft_tensor_parallel_size
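# Example: with target tensor_parallel_size=4, the valid draft values are
# 1 and 4; when unset, the draft inherits 4, except for mlp_speculator,
# which is forced to tp=1.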
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
max_parallel_loading_workers,
disable_custom_all_reduce=target_parallel_config.
disable_custom_all_reduce,
ray_workers_use_nsight=target_parallel_config.
ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
)
return draft_parallel_config
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative model unless the draft model config contains an "
"n_predict parameter.")
if self.num_speculative_tokens <= 0:
raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).")
if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config)
if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported):
raise ValueError(
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}")
return self
@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per
step, in addition to the slots allocated for each known token.
This is equal to the number of speculative tokens, as each speculative
token must be scored.
"""
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"