vllm/vllm/config/speculative.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
from typing import TYPE_CHECKING, Any, Literal, get_args

from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference

if TYPE_CHECKING:
    from transformers import PretrainedConfig

    import vllm.model_executor.layers.quantization as me_quant
else:
    PretrainedConfig = Any
    me_quant = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.layers.quantization"
    )

logger = init_logger(__name__)

MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "ernie_mtp",
    "qwen3_next_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]

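# Illustrative usage sketch (assumption: the `speculative_config` engine
# argument and the `vllm.LLM` entrypoint are defined outside this file; the
# keys mirror the fields of `SpeculativeConfig` below), e.g. roughly:
#
#     from vllm import LLM
#
#     llm = LLM(
#         model="meta-llama/Llama-3.1-8B-Instruct",
#         speculative_config={
#             "method": "ngram",
#             "num_speculative_tokens": 4,
#             "prompt_lookup_max": 4,
#             "prompt_lookup_min": 2,
#         },
#     )
#
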
@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: bool | None = None
    """Override the default enforce_eager from model_config"""

    # General speculative decoding control
    num_speculative_tokens: int = Field(default=None, gt=0)
    """The number of speculative tokens, if provided. It will default to the
    number in the draft model config if present; otherwise, it is required."""
    model: str | None = None
    """The name of the draft model, eagle head, or additional weights, if
    provided."""
    method: SpeculativeMethod | None = None
    """The name of the speculative method to use. If users provide and set the
    `model` param, the speculative method type will be detected automatically
    if possible. If the `model` param is not provided, the method name must be
    provided.

    If using the `ngram` method, the related configuration `prompt_lookup_max`
    and `prompt_lookup_min` should be considered."""
    draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
    """The degree of tensor parallelism for the draft model. Can only be 1
    or the same as the target model's tensor parallel size."""

    # Draft model configuration
    quantization: me_quant.QuantizationMethods | None = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
    revision: str | None = None
    """The specific model version to use for the draft model. It can be a
    branch name, a tag name, or a commit id. If unspecified, will use the
    default version."""
    code_revision: str | None = None
    """The specific revision to use for the draft model code on Hugging Face
    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
    will use the default version."""

    # Advanced control
    disable_by_batch_size: int | None = Field(default=None, ge=2)
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: int | None = Field(default=None, ge=1)
    """Maximum size of the ngram token window when using the ngram proposer;
    required when method is set to ngram."""
    prompt_lookup_min: int | None = Field(default=None, ge=1)
    """Minimum size of the ngram token window when using the ngram proposer,
    if provided. Defaults to 1."""
    speculative_token_tree: str | None = None
    """Specifies the tree structure for speculative token generation."""
    # required configuration params passed from engine
    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
    target_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""

    # params generated in the post-init stage
    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model, initialized internally."""
    draft_parallel_config: SkipValidation[ParallelConfig] = None  # type: ignore
    """The parallel configuration for the draft model, initialized internally."""

    # Suffix decoding configuration
    suffix_decoding_max_tree_depth: int = 24
    """The maximum depth of the suffix decoding global and prompt trees. The
    tree depth limits the sum of the prefix match and speculation lengths."""
    suffix_decoding_max_cached_requests: int = 10000
    """The maximum number of requests to cache in the global suffix tree. If
    exceeded, will trigger eviction in FIFO order. If set to 0, the global
    suffix tree is disabled and past responses are not cached (prompt trees
    are still used)."""
    suffix_decoding_max_spec_factor: float = 1.0
    """The maximum spec factor for suffix decoding. The spec factor controls
    speculation lengths based on the prefix match length: max_spec_tokens =
    max_spec_factor * prefix_match_length."""
    suffix_decoding_min_token_prob: float = 0.1
    """The minimum token probability for suffix decoding. Will only speculate
    tokens with estimated probability (based on frequency counts) greater than
    or equal to this value."""
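    # Worked example of the suffix decoding bound above (illustrative values):
    # with suffix_decoding_max_spec_factor=1.0 and a prefix match of 8 tokens,
    # at most 8 tokens are speculated for that step, further capped by
    # num_speculative_tokens and suffix_decoding_max_tree_depth.
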
    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        # Eagle3 affects the computation graph because it returns intermediate
        # hidden states in addition to the final hidden state.
        factors.append(self.method == "eagle3")
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

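    # Note (illustrative consequence of the factors above): only graph-affecting
    # fields enter the hash, so two configs that differ only in, e.g.,
    # num_speculative_tokens currently produce the same hash, while an "eagle3"
    # config hashes differently from a non-eagle3 one.
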
    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        initial_architecture = hf_config.architectures[0]
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
            )

        if hf_config.model_type == "pangu_ultra_moe":
            hf_config.model_type = "pangu_ultra_moe_mtp"
        if hf_config.model_type == "pangu_ultra_moe_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
            )

        if hf_config.architectures[0] == "MiMoForCausalLM":
            hf_config.model_type = "mimo_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["MiMoMTPModel"],
                }
            )

        if hf_config.architectures[0] == "Glm4MoeForCausalLM":
            hf_config.model_type = "glm4_moe_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {
                    "num_hidden_layers": 0,
                    "n_predict": n_predict,
                    "architectures": ["Glm4MoeMTPModel"],
                }
            )

        if hf_config.model_type == "ernie4_5_moe":
            hf_config.model_type = "ernie_mtp"
        if hf_config.model_type == "ernie_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
            )

        if hf_config.model_type == "qwen3_next":
            hf_config.model_type = "qwen3_next_mtp"
        if hf_config.model_type == "qwen3_next_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
            )

        if hf_config.model_type == "longcat_flash":
            hf_config.model_type = "longcat_flash_mtp"
            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
            hf_config.update(
                {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
            )

        if initial_architecture == "MistralLarge3ForCausalLM":
            hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

        return hf_config

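    # Illustrative effect of the override above (a sketch, not a test): for a
    # DeepSeek-V3 checkpoint whose config has model_type="deepseek_v3" and
    # num_nextn_predict_layers=1, the returned config has
    # model_type="deepseek_mtp", n_predict=1, and
    # architectures=["DeepSeekMTPModel"].
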
    def __post_init__(self):
        # Note: "method" is a new parameter that helps to extend the
        # configuration of non-model-based proposers, and the "model" parameter
        # will be used to set the draft model, eagle head, or additional weights
        # when needed. If users do not specify "method", the speculative method
        # will be detected automatically if possible. If the speculative method
        # cannot be detected, it will be considered "draft_model" by default.

        if self.method in get_args(MTPModelTypes) and self.method != "mtp":
            logger.warning(
                "method `%s` is deprecated and replaced with mtp.", self.method
            )
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            if self.method == "mtp":
                if self.target_model_config is None:
                    raise ValueError("target_model_config must be present for mtp")
                if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                    # FIXME(luccafong): cudagraph with v32 MTP is not supported;
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of the draft model for cases such as
                # --quantization fp8 with a bf16 checkpoint.
                if not self.quantization:
                    self.quantization = self.target_model_config.quantization
            elif self.method in ("ngram", "[ngram]"):
                self.model = "ngram"
            elif self.method == "suffix":
                self.model = "suffix"
            else:
                raise ValueError(
                    "num_speculative_tokens was provided without a speculative model."
                )

        # Automatically configure the method for ngram when "model" is used
        # instead of "method"
        if self.method is None and (
            self.model is not None and self.model in ("ngram", "[ngram]")
        ):
            self.method = "ngram"

        if self.method in ("ngram", "[ngram]"):
            # Unified to "ngram" internally
            self.method = "ngram"
            # Set default values if not provided
            if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                # TODO(woosuk): Tune these values. They are arbitrarily chosen.
                self.prompt_lookup_min = 5
                self.prompt_lookup_max = 5
            elif self.prompt_lookup_min is None:
                if self.prompt_lookup_max is None:
                    raise ValueError(
                        "Either prompt_lookup_max or prompt_lookup_min must be "
                        "provided when using the ngram method."
                    )
                self.prompt_lookup_min = self.prompt_lookup_max
            elif self.prompt_lookup_max is None:
                if self.prompt_lookup_min is None:
                    raise ValueError(
                        "Either prompt_lookup_max or prompt_lookup_min must be "
                        "provided when using the ngram method."
                    )
                self.prompt_lookup_max = self.prompt_lookup_min

            # Validate values
            if self.prompt_lookup_min > self.prompt_lookup_max:
                raise ValueError(
                    f"prompt_lookup_min={self.prompt_lookup_min} must "
                    f"be <= prompt_lookup_max={self.prompt_lookup_max}"
                )

            # TODO: currently we still need to extract vocab_size from the
            # target model config; in the future, we may refactor it out and
            # set the draft-related config to None here.
            self.draft_model_config = self.target_model_config
            self.draft_parallel_config = self.target_parallel_config
        elif self.method == "suffix":
            self._validate_suffix_decoding()
        else:
            self.prompt_lookup_max = 0
            self.prompt_lookup_min = 0

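        # Worked example of the ngram defaults handled above (illustrative):
        # with method="ngram" and neither prompt_lookup_min nor
        # prompt_lookup_max set, both default to 5; with only
        # prompt_lookup_max=4 set, prompt_lookup_min is also set to 4.
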
        if self.model is not None:
            self.draft_model_config = ModelConfig(
                model=self.model,
                runner="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                allowed_media_domains=self.target_model_config.allowed_media_domains,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

            # Automatically detect the method
            if self.method in ("eagle", "eagle3"):
                pass
            # examples:
            # yuhuili/EAGLE-LLaMA3-Instruct-8B
            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
            # AngelSlim/Qwen3-8B_eagle3
            elif "eagle-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif "eagle3" in self.draft_model_config.model.lower():
                self.method = "eagle3"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                self.method = "mlp_speculator"
            elif self.draft_model_config.hf_config.model_type in get_args(
                MTPModelTypes
            ):
                self.method = "mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "Enabling num_speculative_tokens > 1 will run "
                        "multiple forward passes on the same MTP layer, "
                        "which may result in a lower acceptance rate."
                    )
            elif self.draft_model_config.hf_config.model_type == "longcat_flash_mtp":
                self.method = "longcat_flash_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "LongCat MTP models only have one layer. "
                        "Might need some code changes to support "
                        "multiple layers."
                    )
            else:
                self.method = "draft_model"
                raise NotImplementedError(
                    "Speculative decoding with a draft model is not "
                    "supported yet. Please consider using other "
                    "speculative decoding methods such as ngram, medusa, "
                    "eagle, or mtp."
                )

            # Replace hf_config for EAGLE draft_model
            if self.method in ("eagle", "eagle3"):
                from vllm.transformers_utils.configs import SpeculatorsConfig
                from vllm.transformers_utils.configs.eagle import EAGLEConfig

                if isinstance(
                    self.draft_model_config.hf_config,
                    (EAGLEConfig, SpeculatorsConfig),
                ):
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method,
                        model_type="eagle",
                    )
                    self.draft_model_config.hf_config = eagle_config

            if self.num_speculative_tokens is not None and hasattr(
                self.draft_model_config.hf_config, "num_lookahead_tokens"
            ):
                self.draft_model_config.hf_config.num_lookahead_tokens = (
                    self.num_speculative_tokens
                )

            n_predict = getattr(
                self.draft_model_config.hf_config, "n_predict", None
            )
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif (
                    self.num_speculative_tokens > n_predict
                    and self.num_speculative_tokens % n_predict != 0
                ):
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}"
                    )

            if self.speculative_token_tree is None:
                # Generate chain of tokens.
                self.speculative_token_tree = str(
                    [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                )
            else:
                # Sort the token tree breadth-first.
                tree_choices = ast.literal_eval(self.speculative_token_tree)
                self.speculative_token_tree = str(
                    sorted(tree_choices, key=lambda t: (len(t), t))
                )
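            # Illustrative: a user-provided tree "[(0, 0), (1,), (0,)]" is
            # normalized by the breadth-first sort above to
            # "[(0,), (1,), (0, 0)]".
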
            self.draft_tensor_parallel_size = (
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config,
                )
            )

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                )
            )

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config, self.draft_tensor_parallel_size
                )
            )

        return self

    def _validate_suffix_decoding(self):
        if not has_arctic_inference():
            raise ImportError(
                "Arctic Inference is required for suffix decoding. "
                "Install via `pip install arctic-inference==0.1.1`."
            )
        if self.num_speculative_tokens is None:
            # Suffix decoding decides the actual number of speculative tokens
            # dynamically and treats num_speculative_tokens as a maximum limit.
            self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
            logger.warning(
                "Defaulted num_speculative_tokens to %s for suffix decoding.",
                self.num_speculative_tokens,
            )
        # Validate values
        if self.suffix_decoding_max_tree_depth < 1:
            raise ValueError(
                f"suffix_decoding_max_tree_depth="
                f"{self.suffix_decoding_max_tree_depth} must be >= 1"
            )
        if self.suffix_decoding_max_cached_requests < 0:
            raise ValueError(
                f"suffix_decoding_max_cached_requests="
                f"{self.suffix_decoding_max_cached_requests} must be >= 0"
            )
        if self.suffix_decoding_max_spec_factor < 0:
            raise ValueError(
                f"suffix_decoding_max_spec_factor="
                f"{self.suffix_decoding_max_spec_factor} must be >= 0"
            )
        if not 0 <= self.suffix_decoding_min_token_prob <= 1:
            raise ValueError(
                f"suffix_decoding_min_token_prob="
                f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
            )

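    # Illustrative defaulting behavior of the validator above: with
    # method="suffix" and num_speculative_tokens unset, it becomes 24
    # (the default of suffix_decoding_max_tree_depth) and a warning is logged.
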
    @staticmethod
    def _maybe_override_draft_max_model_len(
        speculative_max_model_len: int | None,
        draft_max_model_len: int,
        target_max_model_len: int,
    ) -> int:
        """Determine the max sequence len for the draft model. This is usually
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.

        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.

        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """
        if speculative_max_model_len is not None:
            if speculative_max_model_len > draft_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {draft_max_model_len=}"
                )
            if speculative_max_model_len > target_max_model_len:
                raise ValueError(
                    f"{speculative_max_model_len=} cannot be "
                    f"larger than {target_max_model_len=}"
                )
            return speculative_max_model_len

        return min(
            draft_max_model_len,
            target_max_model_len,
        )

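    # Worked example (illustrative numbers): draft_max_model_len=4096 and
    # target_max_model_len=8192 yield 4096; additionally passing
    # speculative_max_model_len=2048 yields 2048, while 16384 would raise.
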
    @staticmethod
    def _verify_and_get_draft_tp(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int | None,
        draft_hf_config: PretrainedConfig,
    ) -> int:
        """
        Verifies and adjusts the tensor parallel size for a draft model
        specified using speculative_draft_tensor_parallel_size.
        """
        # If speculative_draft_tensor_parallel_size is unset, set it
        # appropriately; otherwise, verify that it is set correctly.
        if speculative_draft_tensor_parallel_size is None:
            if draft_hf_config.model_type == "mlp_speculator":
                speculative_draft_tensor_parallel_size = 1
                if target_parallel_config.tensor_parallel_size > 1:
                    logger.warning(
                        "%s cannot currently be run with tp>1; "
                        "setting speculative_draft_tensor_parallel_size=1",
                        draft_hf_config.model_type,
                    )
            else:
                speculative_draft_tensor_parallel_size = (
                    target_parallel_config.tensor_parallel_size
                )
        elif speculative_draft_tensor_parallel_size not in (
            1,
            target_parallel_config.tensor_parallel_size,
        ):
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} must be "
                f"either 1 or equal to the target model's tensor_parallel_size."
            )
        return speculative_draft_tensor_parallel_size

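    # Illustrative: with a target tensor_parallel_size of 4, an unset draft TP
    # defaults to 4 (or to 1 for mlp_speculator drafts); an explicit value of
    # 2 would be rejected by the check above.
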
    @staticmethod
    def create_draft_parallel_config(
        target_parallel_config: ParallelConfig,
        speculative_draft_tensor_parallel_size: int,
    ) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config, except the tp_size.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
            tensor_parallel_size=speculative_draft_tensor_parallel_size,
            distributed_executor_backend=target_parallel_config.distributed_executor_backend,
            max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
            ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        if self.num_speculative_tokens is None:
            raise ValueError(
                "num_speculative_tokens must be provided with a "
                "speculative model unless the draft model config contains an "
                "n_predict parameter."
            )
        if self.num_speculative_tokens <= 0:
            raise ValueError(
                "Expected num_speculative_tokens to be greater "
                f"than zero ({self.num_speculative_tokens})."
            )

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config
            )

        if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
            raise ValueError(
                "Expected the batch size threshold for disabling "
                "speculative decoding to be > 1, but got "
                f"{self.disable_by_batch_size=}"
            )

        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
        if (
            self.method == "eagle3"
            and self.target_model_config
            and not any(
                supported_model in self.target_model_config.hf_text_config.model_type
                for supported_model in eagle3_target_supported
            )
        ):
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )

        return self

    def use_eagle(self) -> bool:
        return self.method in ("eagle", "eagle3", "mtp")

    def __repr__(self) -> str:
        method = self.method
        model = None if method in ("ngram", "suffix") else self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"