mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 23:37:13 +08:00
feat: Enable engine-level arguments with speculators models (#25250)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
71f2b5ddea
commit
791089df20
@@ -3,38 +3,52 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import SpeculativeConfig
|
||||||
from vllm.model_executor.models.interfaces import supports_eagle3
|
from vllm.model_executor.models.interfaces import supports_eagle3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("model_path", [
|
||||||
"model_path",
|
pytest.param(
|
||||||
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
|
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
|
||||||
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
|
id="llama3-eagle3-speculator"),
|
||||||
|
pytest.param(
|
||||||
|
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
|
||||||
|
id="qwen3-eagle3-speculator"),
|
||||||
|
])
|
||||||
|
def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
|
||||||
|
monkeypatch):
|
||||||
|
"""
|
||||||
|
Test Eagle3 speculators models properly initialize speculative decoding.
|
||||||
|
|
||||||
|
This test verifies:
|
||||||
|
1. Eagle3 support is detected for the model
|
||||||
|
2. Speculative config is automatically initialized from embedded config
|
||||||
|
3. The draft model path is correctly set to the speculators model
|
||||||
|
4. Speculative tokens count is valid
|
||||||
|
5. Text generation works with speculative decoding enabled
|
||||||
|
"""
|
||||||
# Set environment variable for V1 engine serialization
|
# Set environment variable for V1 engine serialization
|
||||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
|
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
|
||||||
|
# Verify Eagle3 support is detected
|
||||||
eagle3_supported = vllm_model.apply_model(supports_eagle3)
|
eagle3_supported = vllm_model.apply_model(supports_eagle3)
|
||||||
assert eagle3_supported
|
assert eagle3_supported, f"Eagle3 should be supported for {model_path}"
|
||||||
|
|
||||||
|
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||||
|
|
||||||
|
assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \
|
||||||
|
"Speculative config should be initialized for speculators model"
|
||||||
|
|
||||||
|
spec_config = vllm_config.speculative_config
|
||||||
|
assert spec_config.num_speculative_tokens > 0, \
|
||||||
|
(f"Expected positive speculative tokens, "
|
||||||
|
f"got {spec_config.num_speculative_tokens}")
|
||||||
|
|
||||||
|
assert spec_config.model == model_path, \
|
||||||
|
f"Draft model should be {model_path}, got {spec_config.model}"
|
||||||
|
|
||||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||||
max_tokens=20)
|
max_tokens=20)
|
||||||
print(vllm_outputs)
|
assert vllm_outputs, \
|
||||||
assert vllm_outputs
|
f"No outputs generated for speculators model {model_path}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model_path",
|
|
||||||
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
|
|
||||||
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
|
|
||||||
# Set environment variable for V1 engine serialization
|
|
||||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
|
||||||
|
|
||||||
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
|
|
||||||
eagle3_supported = vllm_model.apply_model(supports_eagle3)
|
|
||||||
assert eagle3_supported
|
|
||||||
|
|
||||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
|
||||||
max_tokens=20)
|
|
||||||
print(vllm_outputs)
|
|
||||||
assert vllm_outputs
|
|
||||||
|
|||||||
@@ -27,8 +27,7 @@ from vllm.transformers_utils.config import (
|
|||||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||||
get_hf_text_config, get_pooling_config,
|
get_hf_text_config, get_pooling_config,
|
||||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
||||||
is_interleaved, maybe_override_with_speculators_target_model,
|
is_interleaved, try_get_generation_config, try_get_safetensors_metadata,
|
||||||
try_get_generation_config, try_get_safetensors_metadata,
|
|
||||||
try_get_tokenizer_config, uses_mrope)
|
try_get_tokenizer_config, uses_mrope)
|
||||||
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
|
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
|
||||||
is_runai_obj_uri)
|
is_runai_obj_uri)
|
||||||
@@ -416,15 +415,6 @@ class ModelConfig:
|
|||||||
|
|
||||||
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
|
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
|
||||||
|
|
||||||
if self.runner != "draft":
|
|
||||||
# If we're not running the draft model, check for speculators config
|
|
||||||
# If speculators config, set model / tokenizer to be target model
|
|
||||||
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
revision=self.revision,
|
|
||||||
trust_remote_code=self.trust_remote_code)
|
|
||||||
|
|
||||||
if (backend := envs.VLLM_ATTENTION_BACKEND
|
if (backend := envs.VLLM_ATTENTION_BACKEND
|
||||||
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
|
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ from vllm.plugins import load_general_plugins
|
|||||||
from vllm.ray.lazy_utils import is_ray_initialized
|
from vllm.ray.lazy_utils import is_ray_initialized
|
||||||
from vllm.reasoning import ReasoningParserManager
|
from vllm.reasoning import ReasoningParserManager
|
||||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||||
from vllm.transformers_utils.config import get_model_path, is_interleaved
|
from vllm.transformers_utils.config import (get_model_path, is_interleaved,
|
||||||
|
maybe_override_with_speculators)
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
||||||
GiB_bytes, get_ip, is_in_ray_actor)
|
GiB_bytes, get_ip, is_in_ray_actor)
|
||||||
@@ -1082,29 +1083,8 @@ class EngineArgs:
|
|||||||
provided as a JSON string input via CLI arguments or directly as a
|
provided as a JSON string input via CLI arguments or directly as a
|
||||||
dictionary from the engine.
|
dictionary from the engine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from vllm.transformers_utils.config import get_config
|
|
||||||
from vllm.transformers_utils.configs.speculators.base import (
|
|
||||||
SpeculatorsConfig)
|
|
||||||
|
|
||||||
if self.speculative_config is None:
|
if self.speculative_config is None:
|
||||||
hf_config = get_config(
|
return None
|
||||||
self.hf_config_path or target_model_config.model,
|
|
||||||
self.trust_remote_code, self.revision, self.code_revision,
|
|
||||||
self.config_format)
|
|
||||||
|
|
||||||
# if loading a SpeculatorsConfig, load the speculative_config
|
|
||||||
# details from the config directly
|
|
||||||
# no user input required / expected
|
|
||||||
if isinstance(hf_config, SpeculatorsConfig):
|
|
||||||
# We create one since we don't create one
|
|
||||||
self.speculative_config = {}
|
|
||||||
self.speculative_config[
|
|
||||||
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
|
|
||||||
self.speculative_config["model"] = target_model_config.model
|
|
||||||
self.speculative_config["method"] = hf_config.method
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Note(Shangming): These parameters are not obtained from the cli arg
|
# Note(Shangming): These parameters are not obtained from the cli arg
|
||||||
# '--speculative-config' and must be passed in when creating the engine
|
# '--speculative-config' and must be passed in when creating the engine
|
||||||
@@ -1139,6 +1119,15 @@ class EngineArgs:
|
|||||||
|
|
||||||
device_config = DeviceConfig(
|
device_config = DeviceConfig(
|
||||||
device=cast(Device, current_platform.device_type))
|
device=cast(Device, current_platform.device_type))
|
||||||
|
|
||||||
|
(self.model, self.tokenizer,
|
||||||
|
self.speculative_config) = maybe_override_with_speculators(
|
||||||
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
revision=self.revision,
|
||||||
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
vllm_speculative_config=self.speculative_config,
|
||||||
|
)
|
||||||
model_config = self.create_model_config()
|
model_config = self.create_model_config()
|
||||||
|
|
||||||
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
|
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
|
||||||
|
|||||||
@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def maybe_override_with_speculators_target_model(
|
def maybe_override_with_speculators(
|
||||||
model: str,
|
model: str,
|
||||||
tokenizer: str,
|
tokenizer: str,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
|
vllm_speculative_config: Optional[dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str, Optional[dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
If running a speculators config, override running model with target model
|
Resolve model configuration when speculators are detected.
|
||||||
|
|
||||||
|
Checks if the provided model is a speculators model and if so, extracts
|
||||||
|
the target model configuration and builds the speculative config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name or path
|
||||||
|
tokenizer: Tokenizer name or path
|
||||||
|
trust_remote_code: Whether to trust remote code
|
||||||
|
revision: Model revision
|
||||||
|
vllm_speculative_config: Existing vLLM speculative config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
|
||||||
"""
|
"""
|
||||||
is_gguf = check_gguf_file(model)
|
is_gguf = check_gguf_file(model)
|
||||||
if is_gguf:
|
if is_gguf:
|
||||||
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
|
|||||||
token=_get_hf_token(),
|
token=_get_hf_token(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
spec_config = config_dict.get("speculators_config", None)
|
speculators_config = config_dict.get("speculators_config")
|
||||||
# Return the target model
|
|
||||||
if spec_config is not None:
|
if speculators_config is None:
|
||||||
model = tokenizer = spec_config["verifier"]["name_or_path"]
|
# No speculators config found, return original values
|
||||||
return model, tokenizer
|
return model, tokenizer, vllm_speculative_config
|
||||||
|
|
||||||
|
# Speculators format detected - process overrides
|
||||||
|
from vllm.transformers_utils.configs.speculators.base import (
|
||||||
|
SpeculatorsConfig)
|
||||||
|
|
||||||
|
vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
|
||||||
|
config_dict=config_dict)
|
||||||
|
|
||||||
|
# Set the draft model to the speculators model
|
||||||
|
vllm_speculative_config["model"] = model
|
||||||
|
|
||||||
|
# Override model and tokenizer with the verifier model from config
|
||||||
|
verifier_model = speculators_config["verifier"]["name_or_path"]
|
||||||
|
model = tokenizer = verifier_model
|
||||||
|
|
||||||
|
return model, tokenizer, vllm_speculative_config
|
||||||
|
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ class SpeculatorsConfig(PretrainedConfig):
|
|||||||
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
|
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
vllm_config = cls.extract_vllm_speculative_config(config_dict)
|
||||||
|
return cls(**vllm_config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_vllm_speculative_config(
|
||||||
|
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
speculators_model_type = config_dict.get("speculators_model_type")
|
speculators_model_type = config_dict.get("speculators_model_type")
|
||||||
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
|
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -34,11 +40,12 @@ class SpeculatorsConfig(PretrainedConfig):
|
|||||||
# TODO: @dsikka - use speculators pydantic model to validate
|
# TODO: @dsikka - use speculators pydantic model to validate
|
||||||
cls.validate_speculators_config(config_dict=config_dict)
|
cls.validate_speculators_config(config_dict=config_dict)
|
||||||
# Convert from speculators config -> format that can be ingested by vLLM
|
# Convert from speculators config -> format that can be ingested by vLLM
|
||||||
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
|
vllm_config = cls.build_vllm_speculative_config(
|
||||||
|
config_dict=config_dict)
|
||||||
# Apply anything specific to the supported algorithm
|
# Apply anything specific to the supported algorithm
|
||||||
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
|
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
|
||||||
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
|
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
|
||||||
return cls(**vllm_config)
|
return vllm_config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
|
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
|
||||||
@@ -60,32 +67,45 @@ class SpeculatorsConfig(PretrainedConfig):
|
|||||||
"'transformer_layer_config' must be a dictionary if provided")
|
"'transformer_layer_config' must be a dictionary if provided")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_speculators_to_vllm(
|
def build_vllm_speculative_config(
|
||||||
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
|
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert speculators config format to vLLM format.
|
Build vLLM-compatible speculative configuration from speculators format.
|
||||||
|
|
||||||
This method handles the translation of field names and structure
|
This method extracts and transforms speculative configuration from the
|
||||||
between speculators and vLLM formats.
|
speculators format into the structure expected by vLLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: Configuration dictionary in speculators format
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with vLLM-compatible configuration
|
Dictionary with vLLM-compatible speculative configuration
|
||||||
"""
|
"""
|
||||||
# Currently we only support one proposal method
|
# Extract speculators configuration
|
||||||
spec_config = config_dict["speculators_config"]
|
spec_config = config_dict["speculators_config"]
|
||||||
first_method = spec_config.get("proposal_methods")[0]
|
|
||||||
num_lookahead_tokens = first_method.get("speculative_tokens")
|
|
||||||
|
|
||||||
if num_lookahead_tokens is None:
|
# Currently we only support one proposal method
|
||||||
|
proposal_methods = spec_config.get("proposal_methods")
|
||||||
|
if not proposal_methods:
|
||||||
|
raise ValueError("No proposal methods found in speculators config")
|
||||||
|
|
||||||
|
first_method = proposal_methods[0]
|
||||||
|
num_speculative_tokens = first_method.get("speculative_tokens")
|
||||||
|
|
||||||
|
if num_speculative_tokens is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Missing 'speculative_tokens' in proposal method. "
|
"Missing 'speculative_tokens' in proposal method. "
|
||||||
f"Got: {first_method}")
|
f"Got: {first_method}")
|
||||||
|
|
||||||
# Build base vLLM config
|
# Build base vLLM speculative configuration
|
||||||
vllm_config = {
|
vllm_config = {
|
||||||
"method": config_dict.get("speculators_model_type"),
|
"method": config_dict.get("speculators_model_type"),
|
||||||
"num_lookahead_tokens": num_lookahead_tokens,
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
"target_model": spec_config.get("verifier")["name_or_path"]
|
"target_model": spec_config.get("verifier")["name_or_path"]
|
||||||
}
|
}
|
||||||
vllm_config.update(config_dict["transformer_layer_config"])
|
|
||||||
|
# Merge transformer layer configuration if present
|
||||||
|
transformer_config = config_dict.get("transformer_layer_config", {})
|
||||||
|
vllm_config.update(transformer_config)
|
||||||
|
|
||||||
return vllm_config
|
return vllm_config
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user