feat: Enable engine-level arguments with speculators models (#25250)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Rahul Tuli 2025-09-21 22:34:45 +05:30 committed by yewentao256
parent 71f2b5ddea
commit 791089df20
5 changed files with 128 additions and 85 deletions
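
For context: with this change, a speculators-format checkpoint can be passed to the engine together with ordinary engine-level arguments, and the speculative decoding settings are derived from the speculators_config embedded in the checkpoint. A minimal offline-inference sketch (the model id is the test checkpoint used in the diff below; the other arguments are standard engine-level options chosen only for illustration):

from vllm import LLM, SamplingParams

# Speculators-format Eagle3 checkpoint; the verifier (target) model and the
# number of speculative tokens are resolved from its embedded config.
llm = LLM(
    model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
    dtype="bfloat16",
    max_model_len=4096,          # engine-level argument
    gpu_memory_utilization=0.8,  # engine-level argument
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)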

View File

@@ -3,38 +3,52 @@
import pytest
import torch
from vllm.config import SpeculativeConfig
from vllm.model_executor.models.interfaces import supports_eagle3
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
@pytest.mark.parametrize("model_path", [
pytest.param(
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
id="llama3-eagle3-speculator"),
pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
id="qwen3-eagle3-speculator"),
])
def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
monkeypatch):
"""
Test that Eagle3 speculators models properly initialize speculative decoding.
This test verifies:
1. Eagle3 support is detected for the model
2. Speculative config is automatically initialized from embedded config
3. The draft model path is correctly set to the speculators model
4. Speculative tokens count is valid
5. Text generation works with speculative decoding enabled
"""
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
# Verify Eagle3 support is detected
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
assert eagle3_supported, f"Eagle3 should be supported for {model_path}"
vllm_config = vllm_model.llm.llm_engine.vllm_config
assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \
"Speculative config should be initialized for speculators model"
spec_config = vllm_config.speculative_config
assert spec_config.num_speculative_tokens > 0, \
(f"Expected positive speculative tokens, "
f"got {spec_config.num_speculative_tokens}")
assert spec_config.model == model_path, \
f"Draft model should be {model_path}, got {spec_config.model}"
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
@pytest.mark.parametrize(
"model_path",
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
assert vllm_outputs, \
f"No outputs generated for speculators model {model_path}"

View File

@@ -27,8 +27,7 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
is_interleaved, maybe_override_with_speculators_target_model,
try_get_generation_config, try_get_safetensors_metadata,
is_interleaved, try_get_generation_config, try_get_safetensors_metadata,
try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
is_runai_obj_uri)
@@ -416,15 +415,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
if self.runner != "draft":
# If we're not running the draft model, check for speculators config
# If speculators config, set model / tokenizer to be target model
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code)
if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
raise ValueError(

View File

@@ -41,7 +41,8 @@ from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import get_model_path, is_interleaved
from vllm.transformers_utils.config import (get_model_path, is_interleaved,
maybe_override_with_speculators)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
@@ -1082,29 +1083,8 @@ class EngineArgs:
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine.
"""
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)
if self.speculative_config is None:
hf_config = get_config(
self.hf_config_path or target_model_config.model,
self.trust_remote_code, self.revision, self.code_revision,
self.config_format)
# if loading a SpeculatorsConfig, load the speculative_config
# details from the config directly
# no user input required / expected
if isinstance(hf_config, SpeculatorsConfig):
# We create one since we don't create one
self.speculative_config = {}
self.speculative_config[
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
self.speculative_config["model"] = target_model_config.model
self.speculative_config["method"] = hf_config.method
else:
return None
return None
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
@@ -1139,6 +1119,15 @@
device_config = DeviceConfig(
device=cast(Device, current_platform.device_type))
(self.model, self.tokenizer,
self.speculative_config) = maybe_override_with_speculators(
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
vllm_speculative_config=self.speculative_config,
)
model_config = self.create_model_config()
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
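
A rough sketch of what this reordering yields when building the configuration programmatically: maybe_override_with_speculators now runs before create_model_config, so the speculative config is populated from the checkpoint without any --speculative-config input (same test checkpoint as above; treat this as illustrative, not a fixed API contract):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
    dtype="bfloat16",
)

# The speculators resolution happens inside create_engine_config().
vllm_config = engine_args.create_engine_config()
print(vllm_config.model_config.model)   # resolved verifier / target model
print(vllm_config.speculative_config)   # includes the draft model and num_speculative_tokens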

View File

@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return config
def maybe_override_with_speculators_target_model(
def maybe_override_with_speculators(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None,
vllm_speculative_config: Optional[dict[str, Any]] = None,
**kwargs,
) -> tuple[str, str]:
) -> tuple[str, str, Optional[dict[str, Any]]]:
"""
If running a speculators config, override running model with target model
Resolve model configuration when speculators are detected.
Checks if the provided model is a speculators model and if so, extracts
the target model configuration and builds the speculative config.
Args:
model: Model name or path
tokenizer: Tokenizer name or path
trust_remote_code: Whether to trust remote code
revision: Model revision
vllm_speculative_config: Existing vLLM speculative config
Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
"""
is_gguf = check_gguf_file(model)
if is_gguf:
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
token=_get_hf_token(),
**kwargs,
)
spec_config = config_dict.get("speculators_config", None)
# Return the target model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]
return model, tokenizer
speculators_config = config_dict.get("speculators_config")
if speculators_config is None:
# No speculators config found, return original values
return model, tokenizer, vllm_speculative_config
# Speculators format detected - process overrides
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)
vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
config_dict=config_dict)
# Set the draft model to the speculators model
vllm_speculative_config["model"] = model
# Override model and tokenizer with the verifier model from config
verifier_model = speculators_config["verifier"]["name_or_path"]
model = tokenizer = verifier_model
return model, tokenizer, vllm_speculative_config
def get_config(
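
To make the resolution concrete, a sketch of a call against a hypothetical speculators checkpoint; the keyword arguments match the signature above, but the model ids and the exact contents of the returned dict are illustrative only:

from vllm.transformers_utils.config import maybe_override_with_speculators

# "my-org/llama3-eagle3-speculator" is a placeholder for a checkpoint whose
# config.json carries a speculators_config block naming a verifier model.
model, tokenizer, spec_config = maybe_override_with_speculators(
    model="my-org/llama3-eagle3-speculator",
    tokenizer="my-org/llama3-eagle3-speculator",
    trust_remote_code=False,
    vllm_speculative_config=None,
)

# model / tokenizer now point at the verifier's name_or_path, while spec_config
# keeps the speculators checkpoint as the draft model, roughly:
# {"method": ..., "num_speculative_tokens": ..., "model": "my-org/llama3-eagle3-speculator", ...}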

View File

@@ -24,6 +24,12 @@ class SpeculatorsConfig(PretrainedConfig):
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
**kwargs)
vllm_config = cls.extract_vllm_speculative_config(config_dict)
return cls(**vllm_config)
@classmethod
def extract_vllm_speculative_config(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
@@ -34,11 +40,12 @@ class SpeculatorsConfig(PretrainedConfig):
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
vllm_config = cls.build_vllm_speculative_config(
config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return cls(**vllm_config)
return vllm_config
@classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
@@ -60,32 +67,45 @@
"'transformer_layer_config' must be a dictionary if provided")
@classmethod
def convert_speculators_to_vllm(
def build_vllm_speculative_config(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
"""
Convert speculators config format to vLLM format.
This method handles the translation of field names and structure
between speculators and vLLM formats.
Returns:
Dictionary with vLLM-compatible configuration
"""
# Currently we only support one proposal method
spec_config = config_dict["speculators_config"]
first_method = spec_config.get("proposal_methods")[0]
num_lookahead_tokens = first_method.get("speculative_tokens")
if num_lookahead_tokens is None:
Build vLLM-compatible speculative configuration from speculators format.
This method extracts and transforms speculative configuration from the
speculators format into the structure expected by vLLM.
Args:
config_dict: Configuration dictionary in speculators format
Returns:
Dictionary with vLLM-compatible speculative configuration
"""
# Extract speculators configuration
spec_config = config_dict["speculators_config"]
# Currently we only support one proposal method
proposal_methods = spec_config.get("proposal_methods")
if not proposal_methods:
raise ValueError("No proposal methods found in speculators config")
first_method = proposal_methods[0]
num_speculative_tokens = first_method.get("speculative_tokens")
if num_speculative_tokens is None:
raise ValueError(
"Missing 'speculative_tokens' in proposal method. "
f"Got: {first_method}")
# Build base vLLM config
# Build base vLLM speculative configuration
vllm_config = {
"method": config_dict.get("speculators_model_type"),
"num_lookahead_tokens": num_lookahead_tokens,
"num_speculative_tokens": num_speculative_tokens,
"target_model": spec_config.get("verifier")["name_or_path"]
}
vllm_config.update(config_dict["transformer_layer_config"])
# Merge transformer layer configuration if present
transformer_config = config_dict.get("transformer_layer_config", {})
vllm_config.update(transformer_config)
return vllm_config
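
As a worked example of the mapping performed above (key names are taken from the code; all values are hypothetical):

# Input in speculators format:
speculators_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
        "proposal_methods": [{"speculative_tokens": 5}],
    },
    "transformer_layer_config": {"hidden_size": 4096},
}

# SpeculatorsConfig.build_vllm_speculative_config(config_dict=speculators_dict)
# would then produce, before the per-algorithm updater runs:
# {
#     "method": "eagle3",
#     "num_speculative_tokens": 5,
#     "target_model": "meta-llama/Llama-3.1-8B-Instruct",
#     "hidden_size": 4096,   # merged from transformer_layer_config
# }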