[Speculative Decoding] Add speculators config support (#21345)

This commit is contained in:
Dipika Sikka 2025-08-01 08:25:18 -04:00 committed by GitHub
parent 87c94bc879
commit dfbc1f8880
9 changed files with 232 additions and 11 deletions

View File

@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path):
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
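
For context, the new test exercises the full flow: pointing vLLM at a speculators-format checkpoint is all that is needed to enable speculative decoding. A minimal offline-inference sketch (the model name is one of the test checkpoints above; the prompt is illustrative):

from vllm import LLM, SamplingParams

# The checkpoint carries its own speculators_config, so no explicit
# --speculative-config is needed: vLLM resolves the target (verifier)
# model and the draft settings from the config itself.
llm = LLM(model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
          dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)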

View File

@@ -39,8 +39,8 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
- try_get_generation_config, try_get_safetensors_metadata,
- try_get_tokenizer_config, uses_mrope)
+ maybe_override_with_speculators_target_model, try_get_generation_config,
+ try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
# yapf conflicts with isort for this block
@@ -535,6 +535,15 @@ class ModelConfig:
"affect the random state of the Python process that "
"launched vLLM.", self.seed)
if self.runner != "draft":
# If we're not running the draft model, check for a speculators config;
# if present, point model / tokenizer at the target (verifier) model
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code)
# Keep set served_model_name before maybe_model_redirect(self.model)
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
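
The override is a no-op for regular checkpoints. An illustrative call (checkpoint and verifier names are examples, not fixed values):

model, tokenizer = maybe_override_with_speculators_target_model(
    model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    tokenizer="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    revision=None,
    trust_remote_code=False)
# Both values now point at the verifier, e.g.
# "meta-llama/Llama-3.1-8B-Instruct"; for non-speculators checkpoints
# the inputs are returned unchanged.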
@@ -606,8 +615,8 @@ class ModelConfig:
self.config_format,
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn)
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(self.hf_text_config,
"attention_chunk_size", None)
@@ -2980,10 +2989,13 @@ class SpeculativeConfig:
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.transformers_utils.configs import (
SpeculatorsConfig)
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
- EAGLEConfig):
+ (EAGLEConfig, SpeculatorsConfig)):
pass
else:
eagle_config = EAGLEConfig(

View File

@@ -978,8 +978,28 @@ class EngineArgs:
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.
"""
+ from vllm.transformers_utils.config import get_config
+ from vllm.transformers_utils.configs.speculators.base import (
+ SpeculatorsConfig)
if self.speculative_config is None:
- return None
+ hf_config = get_config(self.hf_config_path or self.model,
+ self.trust_remote_code, self.revision,
+ self.code_revision, self.config_format)
+ # If loading a SpeculatorsConfig, load the speculative_config
+ # details from the config directly;
+ # no user input is required / expected
+ if isinstance(hf_config, SpeculatorsConfig):
+ # Create a speculative_config since the user did not supply one
+ self.speculative_config = {}
+ self.speculative_config[
+ "num_speculative_tokens"] = hf_config.num_lookahead_tokens
+ self.speculative_config["model"] = self.model
+ self.speculative_config["method"] = hf_config.method
+ else:
+ return None
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
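
To illustrate, the auto-populated dict ends up with this shape for an Eagle-3 speculators checkpoint (values are illustrative and come from the SpeculatorsConfig fields read above):

self.speculative_config = {
    "num_speculative_tokens": 5,  # hf_config.num_lookahead_tokens
    "model": "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    "method": "eagle3",  # hf_config.method
}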

View File

@@ -51,6 +51,25 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if getattr(config, "norm_before_residual", False):
self._residual_norm = self._norm_before_residual
else:
self._residual_norm = self._norm_after_residual
def _norm_before_residual(
self,
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.hidden_norm(hidden_states)
residual = hidden_states
return hidden_states, residual
def _norm_after_residual(
self,
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.hidden_norm(hidden_states)
return hidden_states, residual
def forward(
self,
positions: torch.Tensor,
@@ -59,9 +78,10 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
- residual = hidden_states
embeds = self.input_layernorm(embeds)
- hidden_states = self.hidden_norm(hidden_states)
+ hidden_states, residual = self._residual_norm(
+ hidden_states=hidden_states)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
# Self Attention
@@ -102,7 +122,7 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
- self.config,
+ config=self.config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
)
])
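
A minimal standalone sketch of the two orderings selected by norm_before_residual (shapes and the RMSNorm instance are illustrative; torch.nn.RMSNorm requires PyTorch >= 2.4):

import torch
from torch import nn

hidden_norm = nn.RMSNorm(8)  # stands in for self.hidden_norm
h = torch.randn(2, 8)        # stands in for hidden_states

# norm_before_residual=True: the residual branch carries the normed states
normed = hidden_norm(h)
hidden_states, residual = normed, normed

# norm_before_residual=False (default): the residual branch carries raw states
hidden_states, residual = hidden_norm(h), h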

View File

@@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
- RWConfig, Step3TextConfig,
- Step3VLConfig, UltravoxConfig)
+ RWConfig, SpeculatorsConfig,
+ Step3TextConfig, Step3VLConfig,
+ UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
"eagle": EAGLEConfig,
"speculators": SpeculatorsConfig,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig,
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return config
def maybe_override_with_speculators_target_model(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> tuple[str, str]:
"""
If running a speculators config, override running model with target model
"""
config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
)
spec_config = config_dict.get("speculators_config")
# Return the target model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]
return model, tokenizer
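
To make the detection concrete, here is a trimmed sketch of the config.json fields this helper and SpeculatorsConfig inspect (model names and values are illustrative; the keys match those read in this commit):

config_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
    },
    "transformer_layer_config": {"hidden_size": 4096},  # draft layer settings
}
# For such a config the helper returns the verifier for both values:
# model == tokenizer == "meta-llama/Llama-3.1-8B-Instruct"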
def get_config(
model: Union[str, Path],
trust_remote_code: bool,
@@ -345,9 +368,12 @@ def get_config(
token=_get_hf_token(),
**kwargs,
)
# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type is None:
model_type = "speculators" if config_dict.get(
"speculators_config") is not None else model_type
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(

View File

@@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig)
@@ -44,6 +45,7 @@ __all__ = [
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"SpeculatorsConfig",
"UltravoxConfig",
"Step3VLConfig",
"Step3VisionEncoderConfig",

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
SUPPORTED_SPECULATORS_TYPES = {}
def register_speculator(name):
def decorator(fn):
SUPPORTED_SPECULATORS_TYPES[name] = fn
return fn
return decorator
@register_speculator("eagle3")
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
"""
Apply Eagle-3 specific configuration transformations.
Eagle-3 specific fields:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection
"""
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
if config_dict.get("target_hidden_size") is not None:
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
vllm_config["norm_before_residual"] = config_dict.get(
"norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]

View File

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Any, Union
from transformers import PretrainedConfig
from vllm.transformers_utils.configs.speculators.algos import (
SUPPORTED_SPECULATORS_TYPES)
__all__ = ["SpeculatorsConfig"]
class SpeculatorsConfig(PretrainedConfig):
model_type = "speculators"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
) -> "SpeculatorsConfig":
"""Load speculators Eagle config and convert to vLLM format."""
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
**kwargs)
speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
"Please ensure you're loading a speculators-format model.")
# validate fields
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return cls(**vllm_config)
@classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
try:
spec_config = config_dict["speculators_config"]
methods = spec_config["proposal_methods"]
first_method = methods[0]
_ = first_method["speculative_tokens"]
_ = spec_config["verifier"]["name_or_path"]
_ = config_dict["speculators_model_type"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError("Invalid speculators config structure") from e
if "transformer_layer_config" not in config_dict:
raise ValueError("Must provide transformer_layer_config")
if not isinstance(config_dict["transformer_layer_config"], dict):
raise TypeError(
"'transformer_layer_config' must be a dictionary if provided")
@classmethod
def convert_speculators_to_vllm(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
"""
Convert speculators config format to vLLM format.
This method handles the translation of field names and structure
between speculators and vLLM formats.
Returns:
Dictionary with vLLM-compatible configuration
"""
# Currently we only support one proposal method
spec_config = config_dict["speculators_config"]
first_method = spec_config.get("proposal_methods")[0]
num_lookahead_tokens = first_method.get("speculative_tokens")
if num_lookahead_tokens is None:
raise ValueError(
"Missing 'speculative_tokens' in proposal method. "
f"Got: {first_method}")
# Build base vLLM config
vllm_config = {
"method": config_dict.get("speculators_model_type"),
"num_lookahead_tokens": num_lookahead_tokens,
"target_model": spec_config.get("verifier")["name_or_path"]
}
vllm_config.update(config_dict["transformer_layer_config"])
return vllm_config
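
Putting the pieces together, a worked sketch of the conversion (input values are illustrative; update_eagle3 is the updater registered in algos.py above):

speculators_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
    },
    "transformer_layer_config": {"hidden_size": 4096},
    "draft_vocab_size": 32000,
    "norm_before_residual": True,
}
vllm_config = SpeculatorsConfig.convert_speculators_to_vllm(speculators_dict)
# -> {"method": "eagle3", "num_lookahead_tokens": 5,
#     "target_model": "meta-llama/Llama-3.1-8B-Instruct",
#     "hidden_size": 4096}
update_eagle3(speculators_dict, vllm_config)
# -> adds draft_vocab_size=32000, norm_before_residual=True,
#    architectures=["Eagle3LlamaForCausalLM"]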