mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
Move SpeculativeConfig from config/__init__.py to config/speculative.py (#24904)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
7ea5c73ad7
commit
0faf3cc3e8
@ -24,7 +24,7 @@ from pydantic import (ConfigDict, SkipValidation, field_validator,
|
|||||||
model_validator)
|
model_validator)
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||||
from typing_extensions import Self, assert_never, runtime_checkable
|
from typing_extensions import assert_never, runtime_checkable
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import version
|
from vllm import version
|
||||||
@ -41,6 +41,7 @@ from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
|||||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||||
ParallelConfig)
|
ParallelConfig)
|
||||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||||
|
from vllm.config.speculative import SpeculativeConfig
|
||||||
from vllm.config.utils import ConfigType, config
|
from vllm.config.utils import ConfigType, config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -1846,528 +1847,6 @@ class DeviceConfig:
|
|||||||
self.device = torch.device(self.device_type)
|
self.device = torch.device(self.device_type)
|
||||||
|
|
||||||
|
|
||||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
|
||||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
|
||||||
"ernie_mtp", "qwen3_next_mtp"]
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass
|
|
||||||
class SpeculativeConfig:
|
|
||||||
"""Configuration for speculative decoding."""
|
|
||||||
|
|
||||||
# General speculative decoding control
|
|
||||||
num_speculative_tokens: SkipValidation[int] = None # type: ignore
|
|
||||||
"""The number of speculative tokens, if provided. It will default to the
|
|
||||||
number in the draft model config if present, otherwise, it is required."""
|
|
||||||
model: Optional[str] = None
|
|
||||||
"""The name of the draft model, eagle head, or additional weights, if
|
|
||||||
provided."""
|
|
||||||
method: Optional[SpeculativeMethod] = None
|
|
||||||
"""The name of the speculative method to use. If users provide and set the
|
|
||||||
`model` param, the speculative method type will be detected automatically
|
|
||||||
if possible, if `model` param is not provided, the method name must be
|
|
||||||
provided.
|
|
||||||
|
|
||||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
|
||||||
`prompt_lookup_min` should be considered."""
|
|
||||||
draft_tensor_parallel_size: Optional[int] = None
|
|
||||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
|
||||||
or the same as the target model's tensor parallel size."""
|
|
||||||
disable_logprobs: bool = True
|
|
||||||
"""If set to True, token log probabilities are not returned during
|
|
||||||
speculative decoding. If set to False, token log probabilities are returned
|
|
||||||
according to the log probability settings in SamplingParams."""
|
|
||||||
|
|
||||||
# Draft model configuration
|
|
||||||
quantization: Optional[me_quant.QuantizationMethods] = None
|
|
||||||
"""Quantization method that was used to quantize the draft model weights.
|
|
||||||
If `None`, we assume the model weights are not quantized. Note that it only
|
|
||||||
takes effect when using the draft model-based speculative method."""
|
|
||||||
max_model_len: Optional[int] = None
|
|
||||||
"""The maximum model length of the draft model. Used when testing the
|
|
||||||
ability to skip speculation for some sequences."""
|
|
||||||
revision: Optional[str] = None
|
|
||||||
"""The specific model version to use for the draft model. It can be a
|
|
||||||
branch name, a tag name, or a commit id. If unspecified, will use the
|
|
||||||
default version."""
|
|
||||||
code_revision: Optional[str] = None
|
|
||||||
"""The specific revision to use for the draft model code on Hugging Face
|
|
||||||
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
|
|
||||||
will use the default version."""
|
|
||||||
|
|
||||||
# Advanced control
|
|
||||||
disable_by_batch_size: Optional[int] = None
|
|
||||||
"""Disable speculative decoding for new incoming requests when the number
|
|
||||||
of enqueued requests is larger than this value, if provided."""
|
|
||||||
|
|
||||||
# Ngram proposer configuration
|
|
||||||
prompt_lookup_max: Optional[int] = None
|
|
||||||
"""Maximum size of ngram token window when using Ngram proposer, required
|
|
||||||
when method is set to ngram."""
|
|
||||||
prompt_lookup_min: Optional[int] = None
|
|
||||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
|
||||||
provided. Defaults to 1."""
|
|
||||||
|
|
||||||
speculative_token_tree: Optional[str] = None
|
|
||||||
"""Specifies the tree structure for speculative token generation.
|
|
||||||
"""
|
|
||||||
# required configuration params passed from engine
|
|
||||||
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
|
||||||
"""The configuration of the target model."""
|
|
||||||
target_parallel_config: SkipValidation[
|
|
||||||
ParallelConfig] = None # type: ignore
|
|
||||||
"""The parallel configuration for the target model."""
|
|
||||||
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
|
||||||
"""Whether vLLM is configured to use chunked prefill or not. Used for
|
|
||||||
raising an error since it's not yet compatible with speculative decode."""
|
|
||||||
disable_log_stats: SkipValidation[bool] = None # type: ignore
|
|
||||||
"""Whether to disable the periodic printing of stage times in speculative
|
|
||||||
decoding."""
|
|
||||||
|
|
||||||
# params generated in the post-init stage
|
|
||||||
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
|
||||||
"""The configuration of the draft model initialized internal."""
|
|
||||||
draft_parallel_config: SkipValidation[
|
|
||||||
ParallelConfig] = None # type: ignore
|
|
||||||
"""The parallel configuration for the draft model initialized internal."""
|
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
|
||||||
"""
|
|
||||||
WARNING: Whenever a new field is added to this config,
|
|
||||||
ensure that it is included in the factors list if
|
|
||||||
it affects the computation graph.
|
|
||||||
|
|
||||||
Provide a hash that uniquely identifies all the configs
|
|
||||||
that affect the structure of the computation
|
|
||||||
graph from input ids/embeddings to the final hidden states,
|
|
||||||
excluding anything before input ids/embeddings and after
|
|
||||||
the final hidden states.
|
|
||||||
"""
|
|
||||||
factors: list[Any] = []
|
|
||||||
# Eagle3 affects the computation graph because it returns intermediate
|
|
||||||
# hidden states in addition to the final hidden state.
|
|
||||||
factors.append(self.method == "eagle3")
|
|
||||||
hash_str = hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
|
||||||
return hash_str
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
|
||||||
if hf_config.model_type == "deepseek_v3":
|
|
||||||
hf_config.model_type = "deepseek_mtp"
|
|
||||||
if hf_config.model_type == "deepseek_mtp":
|
|
||||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
|
||||||
hf_config.update({
|
|
||||||
"n_predict": n_predict,
|
|
||||||
"architectures": ["DeepSeekMTPModel"]
|
|
||||||
})
|
|
||||||
|
|
||||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
|
||||||
hf_config.model_type = "mimo_mtp"
|
|
||||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
|
||||||
hf_config.update({
|
|
||||||
"num_hidden_layers": 0,
|
|
||||||
"n_predict": n_predict,
|
|
||||||
"architectures": ["MiMoMTPModel"]
|
|
||||||
})
|
|
||||||
|
|
||||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
|
||||||
hf_config.model_type = "glm4_moe_mtp"
|
|
||||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
|
||||||
hf_config.update({
|
|
||||||
"num_hidden_layers": 0,
|
|
||||||
"n_predict": n_predict,
|
|
||||||
"architectures": ["Glm4MoeMTPModel"]
|
|
||||||
})
|
|
||||||
|
|
||||||
if hf_config.model_type == "ernie4_5_moe":
|
|
||||||
hf_config.model_type = "ernie_mtp"
|
|
||||||
if hf_config.model_type == "ernie_mtp":
|
|
||||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
|
||||||
hf_config.update({
|
|
||||||
"n_predict": n_predict,
|
|
||||||
"architectures": ["ErnieMTPModel"]
|
|
||||||
})
|
|
||||||
|
|
||||||
if hf_config.model_type == "qwen3_next":
|
|
||||||
hf_config.model_type = "qwen3_next_mtp"
|
|
||||||
if hf_config.model_type == "qwen3_next_mtp":
|
|
||||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
|
||||||
hf_config.update({
|
|
||||||
"n_predict": n_predict,
|
|
||||||
"architectures": ["Qwen3NextMTP"]
|
|
||||||
})
|
|
||||||
|
|
||||||
return hf_config
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
|
|
||||||
# Note: "method" is a new parameter that helps to extend the
|
|
||||||
# configuration of non-model-based proposers, and the "model" parameter
|
|
||||||
# will be used to set the draft model, eagle head, or additional weight
|
|
||||||
# when needed. If users do not specify "method", the speculative method
|
|
||||||
# will be detected automatically if possible. If the speculative method
|
|
||||||
# can not be detected, it will be considered as the "draft_model" by
|
|
||||||
# default.
|
|
||||||
|
|
||||||
if self.model is None and self.num_speculative_tokens is not None:
|
|
||||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
|
||||||
# mtp acceleration for more models besides deepseek_v3
|
|
||||||
if self.target_model_config and \
|
|
||||||
(self.target_model_config.hf_text_config.model_type \
|
|
||||||
== "deepseek_v3" or
|
|
||||||
self.target_model_config.hf_text_config.model_type in
|
|
||||||
("mimo","ernie4_5_moe", "qwen3_next")):
|
|
||||||
# use the draft model from the same model:
|
|
||||||
self.model = self.target_model_config.model
|
|
||||||
# Align the quantization of draft model for cases such as
|
|
||||||
# --quantization fp8 with a bf16 checkpoint.
|
|
||||||
if not self.quantization:
|
|
||||||
self.quantization = self.target_model_config.quantization
|
|
||||||
elif self.method in ("ngram", "[ngram]"):
|
|
||||||
self.model = "ngram"
|
|
||||||
else:
|
|
||||||
raise ValueError("num_speculative_tokens was provided without "
|
|
||||||
"speculative model.")
|
|
||||||
|
|
||||||
# Automatically configure the method for ngram when "model" is used
|
|
||||||
# instead of "method"
|
|
||||||
if self.method is None and (self.model is not None
|
|
||||||
and self.model in ("ngram", "[ngram]")):
|
|
||||||
self.method = "ngram"
|
|
||||||
|
|
||||||
if self.method in ("ngram", "[ngram]"):
|
|
||||||
# Unified to "ngram" internally
|
|
||||||
self.method = "ngram"
|
|
||||||
# Set default values if not provided
|
|
||||||
if (self.prompt_lookup_min is None
|
|
||||||
and self.prompt_lookup_max is None):
|
|
||||||
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
|
||||||
self.prompt_lookup_min = 5
|
|
||||||
self.prompt_lookup_max = 5
|
|
||||||
elif self.prompt_lookup_min is None:
|
|
||||||
assert self.prompt_lookup_max is not None
|
|
||||||
self.prompt_lookup_min = self.prompt_lookup_max
|
|
||||||
elif self.prompt_lookup_max is None:
|
|
||||||
assert self.prompt_lookup_min is not None
|
|
||||||
self.prompt_lookup_max = self.prompt_lookup_min
|
|
||||||
|
|
||||||
# Validate values
|
|
||||||
if self.prompt_lookup_min < 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
|
|
||||||
if self.prompt_lookup_max < 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
|
|
||||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
|
||||||
raise ValueError(
|
|
||||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
|
||||||
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
|
|
||||||
|
|
||||||
# TODO: current we still need extract vocab_size from target model
|
|
||||||
# config, in future, we may try refactor it out, and set
|
|
||||||
# draft related config as None here.
|
|
||||||
self.draft_model_config = self.target_model_config
|
|
||||||
self.draft_parallel_config = self.target_parallel_config
|
|
||||||
else:
|
|
||||||
self.prompt_lookup_max = 0
|
|
||||||
self.prompt_lookup_min = 0
|
|
||||||
|
|
||||||
if self.model is not None:
|
|
||||||
self.draft_model_config = ModelConfig(
|
|
||||||
model=self.model,
|
|
||||||
runner="draft",
|
|
||||||
tokenizer=self.target_model_config.tokenizer,
|
|
||||||
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
|
||||||
trust_remote_code=self.target_model_config.
|
|
||||||
trust_remote_code,
|
|
||||||
allowed_local_media_path=self.target_model_config.
|
|
||||||
allowed_local_media_path,
|
|
||||||
dtype=self.target_model_config.dtype,
|
|
||||||
seed=self.target_model_config.seed,
|
|
||||||
revision=self.revision,
|
|
||||||
code_revision=self.code_revision,
|
|
||||||
tokenizer_revision=self.target_model_config.
|
|
||||||
tokenizer_revision,
|
|
||||||
spec_target_max_model_len=self.target_model_config.
|
|
||||||
max_model_len,
|
|
||||||
quantization=self.quantization,
|
|
||||||
enforce_eager=self.target_model_config.enforce_eager,
|
|
||||||
max_seq_len_to_capture=self.target_model_config.
|
|
||||||
max_seq_len_to_capture,
|
|
||||||
max_logprobs=self.target_model_config.max_logprobs,
|
|
||||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Automatically detect the method
|
|
||||||
if self.method in ('eagle', 'eagle3'):
|
|
||||||
pass
|
|
||||||
# examples:
|
|
||||||
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
|
||||||
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
|
||||||
# AngelSlim/Qwen3-8B_eagle3
|
|
||||||
elif "eagle-" in self.draft_model_config.model.lower():
|
|
||||||
self.method = "eagle"
|
|
||||||
elif "eagle3" in self.draft_model_config.model.lower():
|
|
||||||
self.method = "eagle3"
|
|
||||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
|
||||||
self.method = "medusa"
|
|
||||||
elif (self.draft_model_config.hf_config.model_type ==
|
|
||||||
"mlp_speculator"):
|
|
||||||
self.method = "mlp_speculator"
|
|
||||||
elif (self.draft_model_config.hf_config.model_type
|
|
||||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
|
||||||
self.method = "deepseek_mtp"
|
|
||||||
if self.num_speculative_tokens > 1:
|
|
||||||
logger.warning(
|
|
||||||
"All Deepseek MTP models only have " \
|
|
||||||
"one layer. Might need some code changes " \
|
|
||||||
"to support multiple layers."
|
|
||||||
)
|
|
||||||
elif (self.draft_model_config.hf_config.model_type ==
|
|
||||||
"ernie_mtp"):
|
|
||||||
self.method = "ernie_mtp"
|
|
||||||
if self.num_speculative_tokens > 1:
|
|
||||||
logger.warning(
|
|
||||||
"All Ernie MTP models only have " \
|
|
||||||
"one layer. Might need some code changes " \
|
|
||||||
"to support multiple layers."
|
|
||||||
)
|
|
||||||
elif (self.draft_model_config.hf_config.model_type ==
|
|
||||||
"qwen3_next_mtp"):
|
|
||||||
self.method = "qwen3_next_mtp"
|
|
||||||
if self.num_speculative_tokens > 1:
|
|
||||||
logger.warning(
|
|
||||||
"All Qwen3Next MTP models only have " \
|
|
||||||
"one layer. Might need some code changes " \
|
|
||||||
"to support multiple layers."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.method = "draft_model"
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Speculative decoding with draft model is not "
|
|
||||||
"supported yet. Please consider using other "
|
|
||||||
"speculative decoding methods such as ngram, medusa, "
|
|
||||||
"eagle, or deepseek_mtp.")
|
|
||||||
|
|
||||||
# Replace hf_config for EAGLE draft_model
|
|
||||||
if self.method in ("eagle", "eagle3"):
|
|
||||||
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
|
||||||
raise ValueError(
|
|
||||||
"Chunked prefill and EAGLE are not compatible "
|
|
||||||
"when using V0.")
|
|
||||||
|
|
||||||
from vllm.transformers_utils.configs import (
|
|
||||||
SpeculatorsConfig)
|
|
||||||
from vllm.transformers_utils.configs.eagle import (
|
|
||||||
EAGLEConfig)
|
|
||||||
|
|
||||||
if isinstance(self.draft_model_config.hf_config,
|
|
||||||
(EAGLEConfig, SpeculatorsConfig)):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
eagle_config = EAGLEConfig(
|
|
||||||
self.draft_model_config.hf_config,
|
|
||||||
method=self.method,
|
|
||||||
model_type="eagle")
|
|
||||||
self.draft_model_config.hf_config = eagle_config
|
|
||||||
|
|
||||||
if (self.num_speculative_tokens is not None
|
|
||||||
and hasattr(self.draft_model_config.hf_config,
|
|
||||||
"num_lookahead_tokens")):
|
|
||||||
self.draft_model_config.hf_config.num_lookahead_tokens = \
|
|
||||||
self.num_speculative_tokens
|
|
||||||
|
|
||||||
n_predict = getattr(self.draft_model_config.hf_config,
|
|
||||||
"n_predict", None)
|
|
||||||
if n_predict is not None:
|
|
||||||
if self.num_speculative_tokens is None:
|
|
||||||
# Default to max value defined in draft model config.
|
|
||||||
self.num_speculative_tokens = n_predict
|
|
||||||
elif self.num_speculative_tokens > n_predict and \
|
|
||||||
self.num_speculative_tokens % n_predict != 0:
|
|
||||||
# Ensure divisibility for MTP module reuse.
|
|
||||||
raise ValueError(
|
|
||||||
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
|
||||||
f" must be divisible by {n_predict=}")
|
|
||||||
|
|
||||||
if self.speculative_token_tree is None:
|
|
||||||
# Generate chain of tokens.
|
|
||||||
self.speculative_token_tree = str([
|
|
||||||
(i + 1) * (0, )
|
|
||||||
for i in range(self.num_speculative_tokens)
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
# Sort the token tree breadth-first.
|
|
||||||
tree_choices = ast.literal_eval(
|
|
||||||
self.speculative_token_tree)
|
|
||||||
self.speculative_token_tree = str(
|
|
||||||
sorted(tree_choices, key=lambda t: (len(t), t)))
|
|
||||||
|
|
||||||
self.draft_tensor_parallel_size = \
|
|
||||||
SpeculativeConfig._verify_and_get_draft_tp(
|
|
||||||
self.target_parallel_config,
|
|
||||||
self.draft_tensor_parallel_size,
|
|
||||||
self.draft_model_config.hf_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_model_config.max_model_len = (
|
|
||||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
|
||||||
self.max_model_len,
|
|
||||||
self.draft_model_config.max_model_len,
|
|
||||||
self.target_model_config.max_model_len,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.draft_parallel_config = (
|
|
||||||
SpeculativeConfig.create_draft_parallel_config(
|
|
||||||
self.target_parallel_config,
|
|
||||||
self.draft_tensor_parallel_size))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _maybe_override_draft_max_model_len(
|
|
||||||
speculative_max_model_len: Optional[int],
|
|
||||||
draft_max_model_len: int,
|
|
||||||
target_max_model_len: int,
|
|
||||||
) -> int:
|
|
||||||
"""Determine the max sequence len for the draft model. This is usually
|
|
||||||
the draft_max_model_len, but may be the target_max_model_len if it is
|
|
||||||
less than the draft_max_model_len, or may be speculative_max_model_len
|
|
||||||
if it is specified.
|
|
||||||
|
|
||||||
This is necessary so that sequences do not exceed the capacity of the
|
|
||||||
draft model or the target model.
|
|
||||||
|
|
||||||
speculative_max_model_len is mainly used for testing that sequences can
|
|
||||||
skip speculation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if speculative_max_model_len is not None:
|
|
||||||
|
|
||||||
if speculative_max_model_len > draft_max_model_len:
|
|
||||||
raise ValueError(f"{speculative_max_model_len=} cannot be "
|
|
||||||
f"larger than {draft_max_model_len=}")
|
|
||||||
|
|
||||||
if speculative_max_model_len > target_max_model_len:
|
|
||||||
raise ValueError(f"{speculative_max_model_len=} cannot be "
|
|
||||||
f"larger than {target_max_model_len=}")
|
|
||||||
|
|
||||||
return speculative_max_model_len
|
|
||||||
|
|
||||||
return min(
|
|
||||||
draft_max_model_len,
|
|
||||||
target_max_model_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _verify_and_get_draft_tp(
|
|
||||||
target_parallel_config: ParallelConfig,
|
|
||||||
speculative_draft_tensor_parallel_size: Optional[int],
|
|
||||||
draft_hf_config: PretrainedConfig) -> int:
|
|
||||||
"""
|
|
||||||
Verifies and adjusts the tensor parallel size for a draft model
|
|
||||||
specified using speculative_draft_tensor_parallel_size.
|
|
||||||
"""
|
|
||||||
# If speculative_draft_tensor_parallel_size is unset then set it
|
|
||||||
# appropriately else verify that it is set correctly.
|
|
||||||
if speculative_draft_tensor_parallel_size is None:
|
|
||||||
if draft_hf_config.model_type == "mlp_speculator":
|
|
||||||
speculative_draft_tensor_parallel_size = 1
|
|
||||||
if target_parallel_config.tensor_parallel_size > 1:
|
|
||||||
logger.warning(
|
|
||||||
"%s cannot currently be run with tp>1; "
|
|
||||||
"setting speculative_draft_tensor_parallel_size=1",
|
|
||||||
draft_hf_config.model_type)
|
|
||||||
else:
|
|
||||||
speculative_draft_tensor_parallel_size = \
|
|
||||||
target_parallel_config.tensor_parallel_size
|
|
||||||
elif speculative_draft_tensor_parallel_size not in (
|
|
||||||
1, target_parallel_config.tensor_parallel_size):
|
|
||||||
raise ValueError(
|
|
||||||
f"{speculative_draft_tensor_parallel_size=} cannot be "
|
|
||||||
f"other value than 1 or target model tensor_parallel_size")
|
|
||||||
return speculative_draft_tensor_parallel_size
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_draft_parallel_config(
|
|
||||||
target_parallel_config: ParallelConfig,
|
|
||||||
speculative_draft_tensor_parallel_size: int,
|
|
||||||
) -> ParallelConfig:
|
|
||||||
"""Create a parallel config for use by the draft worker.
|
|
||||||
|
|
||||||
This is mostly a copy of the target parallel config, except the tp_size.
|
|
||||||
"""
|
|
||||||
draft_parallel_config = ParallelConfig(
|
|
||||||
pipeline_parallel_size=target_parallel_config.
|
|
||||||
pipeline_parallel_size,
|
|
||||||
tensor_parallel_size=speculative_draft_tensor_parallel_size,
|
|
||||||
distributed_executor_backend=target_parallel_config.
|
|
||||||
distributed_executor_backend,
|
|
||||||
max_parallel_loading_workers=target_parallel_config.
|
|
||||||
max_parallel_loading_workers,
|
|
||||||
disable_custom_all_reduce=target_parallel_config.
|
|
||||||
disable_custom_all_reduce,
|
|
||||||
ray_workers_use_nsight=target_parallel_config.
|
|
||||||
ray_workers_use_nsight,
|
|
||||||
placement_group=target_parallel_config.placement_group,
|
|
||||||
)
|
|
||||||
|
|
||||||
return draft_parallel_config
|
|
||||||
|
|
||||||
@model_validator(mode='after')
|
|
||||||
def _verify_args(self) -> Self:
|
|
||||||
if self.num_speculative_tokens is None:
|
|
||||||
raise ValueError(
|
|
||||||
"num_speculative_tokens must be provided with "
|
|
||||||
"speculative model unless the draft model config contains an "
|
|
||||||
"n_predict parameter.")
|
|
||||||
|
|
||||||
if self.num_speculative_tokens <= 0:
|
|
||||||
raise ValueError("Expected num_speculative_tokens to be greater "
|
|
||||||
f"than zero ({self.num_speculative_tokens}).")
|
|
||||||
|
|
||||||
if self.draft_model_config:
|
|
||||||
self.draft_model_config.verify_with_parallel_config(
|
|
||||||
self.draft_parallel_config)
|
|
||||||
|
|
||||||
if (self.disable_by_batch_size is not None
|
|
||||||
and self.disable_by_batch_size < 2):
|
|
||||||
raise ValueError("Expect the batch size threshold of disabling "
|
|
||||||
"speculative decoding is > 1, but got "
|
|
||||||
f"{self.disable_by_batch_size=}")
|
|
||||||
|
|
||||||
eagle3_target_supported = ["llama", "qwen"]
|
|
||||||
if self.method == "eagle3" and self.target_model_config and not any(
|
|
||||||
supported_model in
|
|
||||||
self.target_model_config.hf_text_config.model_type
|
|
||||||
for supported_model in eagle3_target_supported):
|
|
||||||
raise ValueError(
|
|
||||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
|
||||||
f"Got {self.target_model_config.hf_text_config.model_type=}")
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_lookahead_slots(self) -> int:
|
|
||||||
"""The number of additional slots the scheduler should allocate per
|
|
||||||
step, in addition to the slots allocated for each known token.
|
|
||||||
|
|
||||||
This is equal to the number of speculative tokens, as each speculative
|
|
||||||
token must be scored.
|
|
||||||
"""
|
|
||||||
return self.num_speculative_tokens
|
|
||||||
|
|
||||||
def use_eagle(self) -> bool:
|
|
||||||
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
|
|
||||||
"qwen3_next_mtp")
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
method = self.method
|
|
||||||
model = None if method == "ngram" else self.draft_model_config.model
|
|
||||||
num_spec_tokens = self.num_speculative_tokens
|
|
||||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class PoolerConfig:
|
class PoolerConfig:
|
||||||
|
|||||||
554
vllm/config/speculative.py
Normal file
554
vllm/config/speculative.py
Normal file
@ -0,0 +1,554 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import hashlib
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import SkipValidation, model_validator
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config.parallel import ParallelConfig
|
||||||
|
from vllm.config.utils import config
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import LazyLoader
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.quantization as me_quant
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
else:
|
||||||
|
PretrainedConfig = Any
|
||||||
|
ModelConfig = Any
|
||||||
|
|
||||||
|
me_quant = LazyLoader("model_executor", globals(),
|
||||||
|
"vllm.model_executor.layers.quantization")
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||||
|
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||||
|
"ernie_mtp", "qwen3_next_mtp"]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class SpeculativeConfig:
|
||||||
|
"""Configuration for speculative decoding."""
|
||||||
|
|
||||||
|
# General speculative decoding control
|
||||||
|
num_speculative_tokens: SkipValidation[int] = None # type: ignore
|
||||||
|
"""The number of speculative tokens, if provided. It will default to the
|
||||||
|
number in the draft model config if present, otherwise, it is required."""
|
||||||
|
model: Optional[str] = None
|
||||||
|
"""The name of the draft model, eagle head, or additional weights, if
|
||||||
|
provided."""
|
||||||
|
method: Optional[SpeculativeMethod] = None
|
||||||
|
"""The name of the speculative method to use. If users provide and set the
|
||||||
|
`model` param, the speculative method type will be detected automatically
|
||||||
|
if possible, if `model` param is not provided, the method name must be
|
||||||
|
provided.
|
||||||
|
|
||||||
|
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||||
|
`prompt_lookup_min` should be considered."""
|
||||||
|
draft_tensor_parallel_size: Optional[int] = None
|
||||||
|
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||||
|
or the same as the target model's tensor parallel size."""
|
||||||
|
disable_logprobs: bool = True
|
||||||
|
"""If set to True, token log probabilities are not returned during
|
||||||
|
speculative decoding. If set to False, token log probabilities are returned
|
||||||
|
according to the log probability settings in SamplingParams."""
|
||||||
|
|
||||||
|
# Draft model configuration
|
||||||
|
quantization: Optional[me_quant.QuantizationMethods] = None
|
||||||
|
"""Quantization method that was used to quantize the draft model weights.
|
||||||
|
If `None`, we assume the model weights are not quantized. Note that it only
|
||||||
|
takes effect when using the draft model-based speculative method."""
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
|
"""The maximum model length of the draft model. Used when testing the
|
||||||
|
ability to skip speculation for some sequences."""
|
||||||
|
revision: Optional[str] = None
|
||||||
|
"""The specific model version to use for the draft model. It can be a
|
||||||
|
branch name, a tag name, or a commit id. If unspecified, will use the
|
||||||
|
default version."""
|
||||||
|
code_revision: Optional[str] = None
|
||||||
|
"""The specific revision to use for the draft model code on Hugging Face
|
||||||
|
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
|
||||||
|
will use the default version."""
|
||||||
|
|
||||||
|
# Advanced control
|
||||||
|
disable_by_batch_size: Optional[int] = None
|
||||||
|
"""Disable speculative decoding for new incoming requests when the number
|
||||||
|
of enqueued requests is larger than this value, if provided."""
|
||||||
|
|
||||||
|
# Ngram proposer configuration
|
||||||
|
prompt_lookup_max: Optional[int] = None
|
||||||
|
"""Maximum size of ngram token window when using Ngram proposer, required
|
||||||
|
when method is set to ngram."""
|
||||||
|
prompt_lookup_min: Optional[int] = None
|
||||||
|
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||||
|
provided. Defaults to 1."""
|
||||||
|
|
||||||
|
speculative_token_tree: Optional[str] = None
|
||||||
|
"""Specifies the tree structure for speculative token generation.
|
||||||
|
"""
|
||||||
|
# required configuration params passed from engine
|
||||||
|
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||||
|
"""The configuration of the target model."""
|
||||||
|
target_parallel_config: SkipValidation[
|
||||||
|
ParallelConfig] = None # type: ignore
|
||||||
|
"""The parallel configuration for the target model."""
|
||||||
|
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
||||||
|
"""Whether vLLM is configured to use chunked prefill or not. Used for
|
||||||
|
raising an error since it's not yet compatible with speculative decode."""
|
||||||
|
disable_log_stats: SkipValidation[bool] = None # type: ignore
|
||||||
|
"""Whether to disable the periodic printing of stage times in speculative
|
||||||
|
decoding."""
|
||||||
|
|
||||||
|
# params generated in the post-init stage
|
||||||
|
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||||
|
"""The configuration of the draft model initialized internal."""
|
||||||
|
draft_parallel_config: SkipValidation[
|
||||||
|
ParallelConfig] = None # type: ignore
|
||||||
|
"""The parallel configuration for the draft model initialized internal."""
|
||||||
|
|
||||||
|
def compute_hash(self) -> str:
|
||||||
|
"""
|
||||||
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
ensure that it is included in the factors list if
|
||||||
|
it affects the computation graph.
|
||||||
|
|
||||||
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
that affect the structure of the computation
|
||||||
|
graph from input ids/embeddings to the final hidden states,
|
||||||
|
excluding anything before input ids/embeddings and after
|
||||||
|
the final hidden states.
|
||||||
|
"""
|
||||||
|
factors: list[Any] = []
|
||||||
|
# Eagle3 affects the computation graph because it returns intermediate
|
||||||
|
# hidden states in addition to the final hidden state.
|
||||||
|
factors.append(self.method == "eagle3")
|
||||||
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
|
usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||||
|
if hf_config.model_type == "deepseek_v3":
|
||||||
|
hf_config.model_type = "deepseek_mtp"
|
||||||
|
if hf_config.model_type == "deepseek_mtp":
|
||||||
|
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||||
|
hf_config.update({
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"architectures": ["DeepSeekMTPModel"]
|
||||||
|
})
|
||||||
|
|
||||||
|
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||||
|
hf_config.model_type = "mimo_mtp"
|
||||||
|
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||||
|
hf_config.update({
|
||||||
|
"num_hidden_layers": 0,
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"architectures": ["MiMoMTPModel"]
|
||||||
|
})
|
||||||
|
|
||||||
|
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||||
|
hf_config.model_type = "glm4_moe_mtp"
|
||||||
|
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||||
|
hf_config.update({
|
||||||
|
"num_hidden_layers": 0,
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"architectures": ["Glm4MoeMTPModel"]
|
||||||
|
})
|
||||||
|
|
||||||
|
if hf_config.model_type == "ernie4_5_moe":
|
||||||
|
hf_config.model_type = "ernie_mtp"
|
||||||
|
if hf_config.model_type == "ernie_mtp":
|
||||||
|
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||||
|
hf_config.update({
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"architectures": ["ErnieMTPModel"]
|
||||||
|
})
|
||||||
|
|
||||||
|
if hf_config.model_type == "qwen3_next":
|
||||||
|
hf_config.model_type = "qwen3_next_mtp"
|
||||||
|
if hf_config.model_type == "qwen3_next_mtp":
|
||||||
|
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||||
|
hf_config.update({
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"architectures": ["Qwen3NextMTP"]
|
||||||
|
})
|
||||||
|
|
||||||
|
return hf_config
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
|
||||||
|
# Note: "method" is a new parameter that helps to extend the
|
||||||
|
# configuration of non-model-based proposers, and the "model" parameter
|
||||||
|
# will be used to set the draft model, eagle head, or additional weight
|
||||||
|
# when needed. If users do not specify "method", the speculative method
|
||||||
|
# will be detected automatically if possible. If the speculative method
|
||||||
|
# can not be detected, it will be considered as the "draft_model" by
|
||||||
|
# default.
|
||||||
|
|
||||||
|
if self.model is None and self.num_speculative_tokens is not None:
|
||||||
|
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||||
|
# mtp acceleration for more models besides deepseek_v3
|
||||||
|
if self.target_model_config and \
|
||||||
|
(self.target_model_config.hf_text_config.model_type \
|
||||||
|
== "deepseek_v3" or
|
||||||
|
self.target_model_config.hf_text_config.model_type in
|
||||||
|
("mimo","ernie4_5_moe", "qwen3_next")):
|
||||||
|
# use the draft model from the same model:
|
||||||
|
self.model = self.target_model_config.model
|
||||||
|
# Align the quantization of draft model for cases such as
|
||||||
|
# --quantization fp8 with a bf16 checkpoint.
|
||||||
|
if not self.quantization:
|
||||||
|
self.quantization = self.target_model_config.quantization
|
||||||
|
elif self.method in ("ngram", "[ngram]"):
|
||||||
|
self.model = "ngram"
|
||||||
|
else:
|
||||||
|
raise ValueError("num_speculative_tokens was provided without "
|
||||||
|
"speculative model.")
|
||||||
|
|
||||||
|
# Automatically configure the method for ngram when "model" is used
|
||||||
|
# instead of "method"
|
||||||
|
if self.method is None and (self.model is not None
|
||||||
|
and self.model in ("ngram", "[ngram]")):
|
||||||
|
self.method = "ngram"
|
||||||
|
|
||||||
|
if self.method in ("ngram", "[ngram]"):
|
||||||
|
# Unified to "ngram" internally
|
||||||
|
self.method = "ngram"
|
||||||
|
# Set default values if not provided
|
||||||
|
if (self.prompt_lookup_min is None
|
||||||
|
and self.prompt_lookup_max is None):
|
||||||
|
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||||
|
self.prompt_lookup_min = 5
|
||||||
|
self.prompt_lookup_max = 5
|
||||||
|
elif self.prompt_lookup_min is None:
|
||||||
|
assert self.prompt_lookup_max is not None
|
||||||
|
self.prompt_lookup_min = self.prompt_lookup_max
|
||||||
|
elif self.prompt_lookup_max is None:
|
||||||
|
assert self.prompt_lookup_min is not None
|
||||||
|
self.prompt_lookup_max = self.prompt_lookup_min
|
||||||
|
|
||||||
|
# Validate values
|
||||||
|
if self.prompt_lookup_min < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
|
||||||
|
if self.prompt_lookup_max < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
|
||||||
|
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||||
|
raise ValueError(
|
||||||
|
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||||
|
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
|
||||||
|
|
||||||
|
# TODO: current we still need extract vocab_size from target model
|
||||||
|
# config, in future, we may try refactor it out, and set
|
||||||
|
# draft related config as None here.
|
||||||
|
self.draft_model_config = self.target_model_config
|
||||||
|
self.draft_parallel_config = self.target_parallel_config
|
||||||
|
else:
|
||||||
|
self.prompt_lookup_max = 0
|
||||||
|
self.prompt_lookup_min = 0
|
||||||
|
|
||||||
|
if self.model is not None:
|
||||||
|
# TODO: Move this import to the top once `ModelConfig`
|
||||||
|
# lives in `vllm.config.model`.
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
self.draft_model_config = ModelConfig(
|
||||||
|
model=self.model,
|
||||||
|
runner="draft",
|
||||||
|
tokenizer=self.target_model_config.tokenizer,
|
||||||
|
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=self.target_model_config.
|
||||||
|
trust_remote_code,
|
||||||
|
allowed_local_media_path=self.target_model_config.
|
||||||
|
allowed_local_media_path,
|
||||||
|
dtype=self.target_model_config.dtype,
|
||||||
|
seed=self.target_model_config.seed,
|
||||||
|
revision=self.revision,
|
||||||
|
code_revision=self.code_revision,
|
||||||
|
tokenizer_revision=self.target_model_config.
|
||||||
|
tokenizer_revision,
|
||||||
|
spec_target_max_model_len=self.target_model_config.
|
||||||
|
max_model_len,
|
||||||
|
quantization=self.quantization,
|
||||||
|
enforce_eager=self.target_model_config.enforce_eager,
|
||||||
|
max_seq_len_to_capture=self.target_model_config.
|
||||||
|
max_seq_len_to_capture,
|
||||||
|
max_logprobs=self.target_model_config.max_logprobs,
|
||||||
|
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Automatically detect the method
|
||||||
|
if self.method in ('eagle', 'eagle3'):
|
||||||
|
pass
|
||||||
|
# examples:
|
||||||
|
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||||
|
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
||||||
|
# AngelSlim/Qwen3-8B_eagle3
|
||||||
|
elif "eagle-" in self.draft_model_config.model.lower():
|
||||||
|
self.method = "eagle"
|
||||||
|
elif "eagle3" in self.draft_model_config.model.lower():
|
||||||
|
self.method = "eagle3"
|
||||||
|
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||||
|
self.method = "medusa"
|
||||||
|
elif (self.draft_model_config.hf_config.model_type ==
|
||||||
|
"mlp_speculator"):
|
||||||
|
self.method = "mlp_speculator"
|
||||||
|
elif (self.draft_model_config.hf_config.model_type
|
||||||
|
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||||
|
self.method = "deepseek_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"All Deepseek MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
elif (self.draft_model_config.hf_config.model_type ==
|
||||||
|
"ernie_mtp"):
|
||||||
|
self.method = "ernie_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"All Ernie MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
elif (self.draft_model_config.hf_config.model_type ==
|
||||||
|
"qwen3_next_mtp"):
|
||||||
|
self.method = "qwen3_next_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"All Qwen3Next MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.method = "draft_model"
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Speculative decoding with draft model is not "
|
||||||
|
"supported yet. Please consider using other "
|
||||||
|
"speculative decoding methods such as ngram, medusa, "
|
||||||
|
"eagle, or deepseek_mtp.")
|
||||||
|
|
||||||
|
# Replace hf_config for EAGLE draft_model
|
||||||
|
if self.method in ("eagle", "eagle3"):
|
||||||
|
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
||||||
|
raise ValueError(
|
||||||
|
"Chunked prefill and EAGLE are not compatible "
|
||||||
|
"when using V0.")
|
||||||
|
|
||||||
|
from vllm.transformers_utils.configs import (
|
||||||
|
SpeculatorsConfig)
|
||||||
|
from vllm.transformers_utils.configs.eagle import (
|
||||||
|
EAGLEConfig)
|
||||||
|
|
||||||
|
if isinstance(self.draft_model_config.hf_config,
|
||||||
|
(EAGLEConfig, SpeculatorsConfig)):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
eagle_config = EAGLEConfig(
|
||||||
|
self.draft_model_config.hf_config,
|
||||||
|
method=self.method,
|
||||||
|
model_type="eagle")
|
||||||
|
self.draft_model_config.hf_config = eagle_config
|
||||||
|
|
||||||
|
if (self.num_speculative_tokens is not None
|
||||||
|
and hasattr(self.draft_model_config.hf_config,
|
||||||
|
"num_lookahead_tokens")):
|
||||||
|
self.draft_model_config.hf_config.num_lookahead_tokens = \
|
||||||
|
self.num_speculative_tokens
|
||||||
|
|
||||||
|
n_predict = getattr(self.draft_model_config.hf_config,
|
||||||
|
"n_predict", None)
|
||||||
|
if n_predict is not None:
|
||||||
|
if self.num_speculative_tokens is None:
|
||||||
|
# Default to max value defined in draft model config.
|
||||||
|
self.num_speculative_tokens = n_predict
|
||||||
|
elif self.num_speculative_tokens > n_predict and \
|
||||||
|
self.num_speculative_tokens % n_predict != 0:
|
||||||
|
# Ensure divisibility for MTP module reuse.
|
||||||
|
raise ValueError(
|
||||||
|
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||||
|
f" must be divisible by {n_predict=}")
|
||||||
|
|
||||||
|
if self.speculative_token_tree is None:
|
||||||
|
# Generate chain of tokens.
|
||||||
|
self.speculative_token_tree = str([
|
||||||
|
(i + 1) * (0, )
|
||||||
|
for i in range(self.num_speculative_tokens)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
# Sort the token tree breadth-first.
|
||||||
|
tree_choices = ast.literal_eval(
|
||||||
|
self.speculative_token_tree)
|
||||||
|
self.speculative_token_tree = str(
|
||||||
|
sorted(tree_choices, key=lambda t: (len(t), t)))
|
||||||
|
|
||||||
|
self.draft_tensor_parallel_size = \
|
||||||
|
SpeculativeConfig._verify_and_get_draft_tp(
|
||||||
|
self.target_parallel_config,
|
||||||
|
self.draft_tensor_parallel_size,
|
||||||
|
self.draft_model_config.hf_config
|
||||||
|
)
|
||||||
|
|
||||||
|
self.draft_model_config.max_model_len = (
|
||||||
|
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||||
|
self.max_model_len,
|
||||||
|
self.draft_model_config.max_model_len,
|
||||||
|
self.target_model_config.max_model_len,
|
||||||
|
))
|
||||||
|
|
||||||
|
self.draft_parallel_config = (
|
||||||
|
SpeculativeConfig.create_draft_parallel_config(
|
||||||
|
self.target_parallel_config,
|
||||||
|
self.draft_tensor_parallel_size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _maybe_override_draft_max_model_len(
|
||||||
|
speculative_max_model_len: Optional[int],
|
||||||
|
draft_max_model_len: int,
|
||||||
|
target_max_model_len: int,
|
||||||
|
) -> int:
|
||||||
|
"""Determine the max sequence len for the draft model. This is usually
|
||||||
|
the draft_max_model_len, but may be the target_max_model_len if it is
|
||||||
|
less than the draft_max_model_len, or may be speculative_max_model_len
|
||||||
|
if it is specified.
|
||||||
|
|
||||||
|
This is necessary so that sequences do not exceed the capacity of the
|
||||||
|
draft model or the target model.
|
||||||
|
|
||||||
|
speculative_max_model_len is mainly used for testing that sequences can
|
||||||
|
skip speculation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if speculative_max_model_len is not None:
|
||||||
|
|
||||||
|
if speculative_max_model_len > draft_max_model_len:
|
||||||
|
raise ValueError(f"{speculative_max_model_len=} cannot be "
|
||||||
|
f"larger than {draft_max_model_len=}")
|
||||||
|
|
||||||
|
if speculative_max_model_len > target_max_model_len:
|
||||||
|
raise ValueError(f"{speculative_max_model_len=} cannot be "
|
||||||
|
f"larger than {target_max_model_len=}")
|
||||||
|
|
||||||
|
return speculative_max_model_len
|
||||||
|
|
||||||
|
return min(
|
||||||
|
draft_max_model_len,
|
||||||
|
target_max_model_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _verify_and_get_draft_tp(
|
||||||
|
target_parallel_config: ParallelConfig,
|
||||||
|
speculative_draft_tensor_parallel_size: Optional[int],
|
||||||
|
draft_hf_config: PretrainedConfig) -> int:
|
||||||
|
"""
|
||||||
|
Verifies and adjusts the tensor parallel size for a draft model
|
||||||
|
specified using speculative_draft_tensor_parallel_size.
|
||||||
|
"""
|
||||||
|
# If speculative_draft_tensor_parallel_size is unset then set it
|
||||||
|
# appropriately else verify that it is set correctly.
|
||||||
|
if speculative_draft_tensor_parallel_size is None:
|
||||||
|
if draft_hf_config.model_type == "mlp_speculator":
|
||||||
|
speculative_draft_tensor_parallel_size = 1
|
||||||
|
if target_parallel_config.tensor_parallel_size > 1:
|
||||||
|
logger.warning(
|
||||||
|
"%s cannot currently be run with tp>1; "
|
||||||
|
"setting speculative_draft_tensor_parallel_size=1",
|
||||||
|
draft_hf_config.model_type)
|
||||||
|
else:
|
||||||
|
speculative_draft_tensor_parallel_size = \
|
||||||
|
target_parallel_config.tensor_parallel_size
|
||||||
|
elif speculative_draft_tensor_parallel_size not in (
|
||||||
|
1, target_parallel_config.tensor_parallel_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"{speculative_draft_tensor_parallel_size=} cannot be "
|
||||||
|
f"other value than 1 or target model tensor_parallel_size")
|
||||||
|
return speculative_draft_tensor_parallel_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_draft_parallel_config(
|
||||||
|
target_parallel_config: ParallelConfig,
|
||||||
|
speculative_draft_tensor_parallel_size: int,
|
||||||
|
) -> ParallelConfig:
|
||||||
|
"""Create a parallel config for use by the draft worker.
|
||||||
|
|
||||||
|
This is mostly a copy of the target parallel config, except the tp_size.
|
||||||
|
"""
|
||||||
|
draft_parallel_config = ParallelConfig(
|
||||||
|
pipeline_parallel_size=target_parallel_config.
|
||||||
|
pipeline_parallel_size,
|
||||||
|
tensor_parallel_size=speculative_draft_tensor_parallel_size,
|
||||||
|
distributed_executor_backend=target_parallel_config.
|
||||||
|
distributed_executor_backend,
|
||||||
|
max_parallel_loading_workers=target_parallel_config.
|
||||||
|
max_parallel_loading_workers,
|
||||||
|
disable_custom_all_reduce=target_parallel_config.
|
||||||
|
disable_custom_all_reduce,
|
||||||
|
ray_workers_use_nsight=target_parallel_config.
|
||||||
|
ray_workers_use_nsight,
|
||||||
|
placement_group=target_parallel_config.placement_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
return draft_parallel_config
|
||||||
|
|
||||||
|
@model_validator(mode='after')
|
||||||
|
def _verify_args(self) -> Self:
|
||||||
|
if self.num_speculative_tokens is None:
|
||||||
|
raise ValueError(
|
||||||
|
"num_speculative_tokens must be provided with "
|
||||||
|
"speculative model unless the draft model config contains an "
|
||||||
|
"n_predict parameter.")
|
||||||
|
|
||||||
|
if self.num_speculative_tokens <= 0:
|
||||||
|
raise ValueError("Expected num_speculative_tokens to be greater "
|
||||||
|
f"than zero ({self.num_speculative_tokens}).")
|
||||||
|
|
||||||
|
if self.draft_model_config:
|
||||||
|
self.draft_model_config.verify_with_parallel_config(
|
||||||
|
self.draft_parallel_config)
|
||||||
|
|
||||||
|
if (self.disable_by_batch_size is not None
|
||||||
|
and self.disable_by_batch_size < 2):
|
||||||
|
raise ValueError("Expect the batch size threshold of disabling "
|
||||||
|
"speculative decoding is > 1, but got "
|
||||||
|
f"{self.disable_by_batch_size=}")
|
||||||
|
|
||||||
|
eagle3_target_supported = ["llama", "qwen"]
|
||||||
|
if self.method == "eagle3" and self.target_model_config and not any(
|
||||||
|
supported_model in
|
||||||
|
self.target_model_config.hf_text_config.model_type
|
||||||
|
for supported_model in eagle3_target_supported):
|
||||||
|
raise ValueError(
|
||||||
|
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||||
|
f"Got {self.target_model_config.hf_text_config.model_type=}")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_lookahead_slots(self) -> int:
|
||||||
|
"""The number of additional slots the scheduler should allocate per
|
||||||
|
step, in addition to the slots allocated for each known token.
|
||||||
|
|
||||||
|
This is equal to the number of speculative tokens, as each speculative
|
||||||
|
token must be scored.
|
||||||
|
"""
|
||||||
|
return self.num_speculative_tokens
|
||||||
|
|
||||||
|
def use_eagle(self) -> bool:
|
||||||
|
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
|
||||||
|
"qwen3_next_mtp")
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
method = self.method
|
||||||
|
model = None if method == "ngram" else self.draft_model_config.model
|
||||||
|
num_spec_tokens = self.num_speculative_tokens
|
||||||
|
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||||
Loading…
x
Reference in New Issue
Block a user