mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 13:42:18 +08:00
[Speculative Decoding] Add speculators config support (#21345)
This commit is contained in:
parent
87c94bc879
commit
dfbc1f8880
16
tests/speculative_decoding/speculators/test_eagle3.py
Normal file
16
tests/speculative_decoding/speculators/test_eagle3.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model_path",
    [
        "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
        "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
    ],
)
def test_llama(vllm_runner, example_prompts, model_path):
    """Smoke-test greedy generation for each speculators-format checkpoint.

    The checkpoint is passed to the runner as the model with no other
    speculative settings; the test only checks that generation returns a
    non-empty result.
    """
    with vllm_runner(model_path, dtype=torch.bfloat16) as runner:
        greedy_outputs = runner.generate_greedy(example_prompts,
                                                max_tokens=20)
        # Printed so the generations are visible in the captured test output.
        print(greedy_outputs)
        assert greedy_outputs
|
||||||
@ -39,8 +39,8 @@ from vllm.transformers_utils.config import (
|
|||||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||||
get_hf_text_config, get_pooling_config,
|
get_hf_text_config, get_pooling_config,
|
||||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
||||||
try_get_generation_config, try_get_safetensors_metadata,
|
maybe_override_with_speculators_target_model, try_get_generation_config,
|
||||||
try_get_tokenizer_config, uses_mrope)
|
try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
|
||||||
from vllm.transformers_utils.s3_utils import S3Model
|
from vllm.transformers_utils.s3_utils import S3Model
|
||||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
@ -535,6 +535,15 @@ class ModelConfig:
|
|||||||
"affect the random state of the Python process that "
|
"affect the random state of the Python process that "
|
||||||
"launched vLLM.", self.seed)
|
"launched vLLM.", self.seed)
|
||||||
|
|
||||||
|
if self.runner != "draft":
|
||||||
|
# If we're not running the draft model, check for speculators config
|
||||||
|
# If speculators config, set model / tokenizer to be target model
|
||||||
|
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
|
||||||
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
revision=self.revision,
|
||||||
|
trust_remote_code=self.trust_remote_code)
|
||||||
|
|
||||||
# Keep set served_model_name before maybe_model_redirect(self.model)
|
# Keep set served_model_name before maybe_model_redirect(self.model)
|
||||||
self.served_model_name = get_served_model_name(self.model,
|
self.served_model_name = get_served_model_name(self.model,
|
||||||
self.served_model_name)
|
self.served_model_name)
|
||||||
@ -606,8 +615,8 @@ class ModelConfig:
|
|||||||
self.config_format,
|
self.config_format,
|
||||||
hf_overrides_kw=hf_overrides_kw,
|
hf_overrides_kw=hf_overrides_kw,
|
||||||
hf_overrides_fn=hf_overrides_fn)
|
hf_overrides_fn=hf_overrides_fn)
|
||||||
self.hf_config = hf_config
|
|
||||||
|
|
||||||
|
self.hf_config = hf_config
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
self.attention_chunk_size = getattr(self.hf_text_config,
|
self.attention_chunk_size = getattr(self.hf_text_config,
|
||||||
"attention_chunk_size", None)
|
"attention_chunk_size", None)
|
||||||
@ -2980,10 +2989,13 @@ class SpeculativeConfig:
|
|||||||
"Chunked prefill and EAGLE are not compatible "
|
"Chunked prefill and EAGLE are not compatible "
|
||||||
"when using V0.")
|
"when using V0.")
|
||||||
|
|
||||||
|
from vllm.transformers_utils.configs import (
|
||||||
|
SpeculatorsConfig)
|
||||||
from vllm.transformers_utils.configs.eagle import (
|
from vllm.transformers_utils.configs.eagle import (
|
||||||
EAGLEConfig)
|
EAGLEConfig)
|
||||||
|
|
||||||
if isinstance(self.draft_model_config.hf_config,
|
if isinstance(self.draft_model_config.hf_config,
|
||||||
EAGLEConfig):
|
(EAGLEConfig, SpeculatorsConfig)):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
eagle_config = EAGLEConfig(
|
eagle_config = EAGLEConfig(
|
||||||
|
|||||||
@ -978,8 +978,28 @@ class EngineArgs:
|
|||||||
provided as a JSON string input via CLI arguments or directly as a
|
provided as a JSON string input via CLI arguments or directly as a
|
||||||
dictionary from the engine.
|
dictionary from the engine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from vllm.transformers_utils.config import get_config
|
||||||
|
from vllm.transformers_utils.configs.speculators.base import (
|
||||||
|
SpeculatorsConfig)
|
||||||
|
|
||||||
if self.speculative_config is None:
|
if self.speculative_config is None:
|
||||||
return None
|
hf_config = get_config(self.hf_config_path or self.model,
|
||||||
|
self.trust_remote_code, self.revision,
|
||||||
|
self.code_revision, self.config_format)
|
||||||
|
|
||||||
|
# if loading a SpeculatorsConfig, load the speculative_config
|
||||||
|
# details from the config directly
|
||||||
|
# no user input required / expected
|
||||||
|
if isinstance(hf_config, SpeculatorsConfig):
|
||||||
|
# Create the speculative_config dict here since none was provided
|
||||||
|
self.speculative_config = {}
|
||||||
|
self.speculative_config[
|
||||||
|
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
|
||||||
|
self.speculative_config["model"] = self.model
|
||||||
|
self.speculative_config["method"] = hf_config.method
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
# Note(Shangming): These parameters are not obtained from the cli arg
|
# Note(Shangming): These parameters are not obtained from the cli arg
|
||||||
# '--speculative-config' and must be passed in when creating the engine
|
# '--speculative-config' and must be passed in when creating the engine
|
||||||
|
|||||||
@ -51,6 +51,25 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
|
|
||||||
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
if getattr(config, "norm_before_residual", False):
|
||||||
|
self._residual_norm = self._norm_before_residual
|
||||||
|
else:
|
||||||
|
self._residual_norm = self._norm_after_residual
|
||||||
|
|
||||||
|
def _norm_before_residual(
        self,
        hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Normalize first, then use the normalized tensor as the residual.

    Returns:
        ``(hidden_states, residual)`` where both entries are the output of
        ``self.hidden_norm``.
    """
    normed = self.hidden_norm(hidden_states)
    # The residual is taken after the norm, so both slots share one tensor.
    return normed, normed
|
||||||
|
|
||||||
|
def _norm_after_residual(
        self,
        hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep the un-normalized input as the residual, then normalize.

    Returns:
        ``(hidden_states, residual)`` where ``hidden_states`` is the output
        of ``self.hidden_norm`` and ``residual`` is the original input.
    """
    # The residual is the input exactly as it was passed in.
    return self.hidden_norm(hidden_states), hidden_states
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@ -59,9 +78,10 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
embeds = self.input_layernorm(embeds)
|
embeds = self.input_layernorm(embeds)
|
||||||
hidden_states = self.hidden_norm(hidden_states)
|
|
||||||
|
hidden_states, residual = self._residual_norm(
|
||||||
|
hidden_states=hidden_states)
|
||||||
|
|
||||||
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
||||||
# Self Attention
|
# Self Attention
|
||||||
@ -102,7 +122,7 @@ class LlamaModel(nn.Module):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
LlamaDecoderLayer(
|
LlamaDecoderLayer(
|
||||||
self.config,
|
config=self.config,
|
||||||
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
||||||
)
|
)
|
||||||
])
|
])
|
||||||
|
|||||||
@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
|
|||||||
MllamaConfig, MLPSpeculatorConfig,
|
MllamaConfig, MLPSpeculatorConfig,
|
||||||
Nemotron_Nano_VL_Config,
|
Nemotron_Nano_VL_Config,
|
||||||
NemotronConfig, NVLM_D_Config,
|
NemotronConfig, NVLM_D_Config,
|
||||||
RWConfig, Step3TextConfig,
|
RWConfig, SpeculatorsConfig,
|
||||||
Step3VLConfig, UltravoxConfig)
|
Step3TextConfig, Step3VLConfig,
|
||||||
|
UltravoxConfig)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
|
|||||||
"mlp_speculator": MLPSpeculatorConfig,
|
"mlp_speculator": MLPSpeculatorConfig,
|
||||||
"medusa": MedusaConfig,
|
"medusa": MedusaConfig,
|
||||||
"eagle": EAGLEConfig,
|
"eagle": EAGLEConfig,
|
||||||
|
"speculators": SpeculatorsConfig,
|
||||||
"nemotron": NemotronConfig,
|
"nemotron": NemotronConfig,
|
||||||
"NVLM_D": NVLM_D_Config,
|
"NVLM_D": NVLM_D_Config,
|
||||||
"ultravox": UltravoxConfig,
|
"ultravox": UltravoxConfig,
|
||||||
@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_override_with_speculators_target_model(
        model: str,
        tokenizer: str,
        trust_remote_code: bool,
        revision: Optional[str] = None) -> tuple[str, str]:
    """
    Swap in the target model when ``model`` carries a speculators config.

    The raw config dict of ``model`` is loaded and checked for a
    ``speculators_config`` entry. If present, both the model and the
    tokenizer are replaced by that entry's ``verifier.name_or_path``;
    otherwise the inputs are returned unchanged.

    Returns:
        The ``(model, tokenizer)`` pair to run with.
    """
    hf_config_dict, _ = PretrainedConfig.get_config_dict(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
        token=_get_hf_token(),
    )
    speculators_section = hf_config_dict.get("speculators_config")
    if speculators_section is None:
        # Not a speculators checkpoint: nothing to override.
        return model, tokenizer

    # The verifier entry names the target model; it also supplies the
    # tokenizer.
    target = speculators_section["verifier"]["name_or_path"]
    return target, target
|
||||||
|
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
model: Union[str, Path],
|
model: Union[str, Path],
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
@ -345,9 +368,12 @@ def get_config(
|
|||||||
token=_get_hf_token(),
|
token=_get_hf_token(),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use custom model class if it's in our registry
|
# Use custom model class if it's in our registry
|
||||||
model_type = config_dict.get("model_type")
|
model_type = config_dict.get("model_type")
|
||||||
|
if model_type is None:
|
||||||
|
model_type = "speculators" if config_dict.get(
|
||||||
|
"speculators_config") is not None else model_type
|
||||||
|
|
||||||
if model_type in _CONFIG_REGISTRY:
|
if model_type in _CONFIG_REGISTRY:
|
||||||
config_class = _CONFIG_REGISTRY[model_type]
|
config_class = _CONFIG_REGISTRY[model_type]
|
||||||
config = config_class.from_pretrained(
|
config = config_class.from_pretrained(
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
|||||||
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
|
||||||
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
|
||||||
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
|
||||||
|
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
|
||||||
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
|
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
|
||||||
Step3VisionEncoderConfig,
|
Step3VisionEncoderConfig,
|
||||||
Step3VLConfig)
|
Step3VLConfig)
|
||||||
@ -44,6 +45,7 @@ __all__ = [
|
|||||||
"NemotronHConfig",
|
"NemotronHConfig",
|
||||||
"Nemotron_Nano_VL_Config",
|
"Nemotron_Nano_VL_Config",
|
||||||
"NVLM_D_Config",
|
"NVLM_D_Config",
|
||||||
|
"SpeculatorsConfig",
|
||||||
"UltravoxConfig",
|
"UltravoxConfig",
|
||||||
"Step3VLConfig",
|
"Step3VLConfig",
|
||||||
"Step3VisionEncoderConfig",
|
"Step3VisionEncoderConfig",
|
||||||
|
|||||||
2
vllm/transformers_utils/configs/speculators/__init__.py
Normal file
2
vllm/transformers_utils/configs/speculators/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
32
vllm/transformers_utils/configs/speculators/algos.py
Normal file
32
vllm/transformers_utils/configs/speculators/algos.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Registry: speculators model type name -> function that applies the
# algorithm-specific updates to the vLLM config dict.
SUPPORTED_SPECULATORS_TYPES = {}


def register_speculator(name):
    """Return a decorator that registers the decorated function as ``name``.

    The function is stored in ``SUPPORTED_SPECULATORS_TYPES`` under ``name``
    and handed back unchanged.
    """

    def _record(updater):
        SUPPORTED_SPECULATORS_TYPES[name] = updater
        return updater

    return _record


@register_speculator("eagle3")
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
    """
    Copy the Eagle-3 specific settings from ``config_dict`` to ``vllm_config``.

    ``vllm_config`` is modified in place. Fields written:
        - draft_vocab_size: taken from ``config_dict`` (None when absent)
        - target_hidden_size: only written when present and not None
        - norm_before_residual: taken from ``config_dict``, default True
        - architectures: always ``["Eagle3LlamaForCausalLM"]``
    """
    vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")

    target_hidden_size = config_dict.get("target_hidden_size")
    if target_hidden_size is not None:
        vllm_config["target_hidden_size"] = target_hidden_size

    vllm_config.update(
        norm_before_residual=config_dict.get("norm_before_residual", True),
        architectures=["Eagle3LlamaForCausalLM"],
    )
|
||||||
91
vllm/transformers_utils/configs/speculators/base.py
Normal file
91
vllm/transformers_utils/configs/speculators/base.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import os
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.transformers_utils.configs.speculators.algos import (
|
||||||
|
SUPPORTED_SPECULATORS_TYPES)
|
||||||
|
|
||||||
|
__all__ = ["SpeculatorsConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
class SpeculatorsConfig(PretrainedConfig):
    """Config for checkpoints saved in the speculators format.

    ``from_pretrained`` reads the raw config dict, validates the speculators
    fields, and flattens them into the keyword arguments this class is
    constructed from.
    """
    model_type = "speculators"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "SpeculatorsConfig":
        """Load a speculators config and convert it to vLLM format.

        Args:
            pretrained_model_name_or_path: Model name or path, forwarded to
                ``get_config_dict``.
            **kwargs: Extra keyword arguments forwarded to
                ``get_config_dict``.

        Returns:
            A ``SpeculatorsConfig`` built from the converted config dict.

        Raises:
            ValueError: If ``speculators_model_type`` is not registered in
                ``SUPPORTED_SPECULATORS_TYPES`` or the speculators config
                structure is invalid.
            TypeError: If ``transformer_layer_config`` is not a dictionary.
        """
        config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
                                             **kwargs)

        speculators_model_type = config_dict.get("speculators_model_type")
        if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
            # Report the registered names only; formatting the registry dict
            # itself would put function reprs into the message.
            raise ValueError(
                "Unsupported speculators_model_type: "
                f"{speculators_model_type!r}. "
                f"Expected one of: {sorted(SUPPORTED_SPECULATORS_TYPES)}. "
                "Please ensure you're loading a speculators-format model.")

        # validate fields
        # TODO: @dsikka - use speculators pydantic model to validate
        cls.validate_speculators_config(config_dict=config_dict)
        # Convert from speculators config -> format that can be ingested by
        # vLLM
        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
        # Apply anything specific to the supported algorithm; the updater
        # mutates vllm_config in place.
        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
        algo_updater(config_dict=config_dict, vllm_config=vllm_config)
        return cls(**vllm_config)

    @classmethod
    def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
        """Check that the fields required for conversion are present.

        Args:
            config_dict: Raw config dictionary of the checkpoint.

        Raises:
            ValueError: If ``speculators_config``, its first proposal
                method's ``speculative_tokens``, its
                ``verifier.name_or_path``, ``speculators_model_type`` or
                ``transformer_layer_config`` is missing.
            TypeError: If ``transformer_layer_config`` is not a dictionary.
        """
        try:
            # Each lookup is done only for its failure mode; the values
            # themselves are not used here.
            spec_config = config_dict["speculators_config"]
            methods = spec_config["proposal_methods"]
            first_method = methods[0]
            _ = first_method["speculative_tokens"]
            _ = spec_config["verifier"]["name_or_path"]
            _ = config_dict["speculators_model_type"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError("Invalid speculators config structure") from e

        if "transformer_layer_config" not in config_dict:
            raise ValueError("Must provide transformer_layer_config")

        if not isinstance(config_dict["transformer_layer_config"], dict):
            raise TypeError(
                "'transformer_layer_config' must be a dictionary if provided")

    @classmethod
    def convert_speculators_to_vllm(
            cls, config_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Convert speculators config format to vLLM format.

        This method handles the translation of field names and structure
        between speculators and vLLM formats.

        Args:
            config_dict: Raw config dictionary of the checkpoint.

        Returns:
            Dictionary with vLLM-compatible configuration: ``method``,
            ``num_lookahead_tokens`` and ``target_model``, followed by every
            entry of ``transformer_layer_config`` (which overrides those
            three keys on a name clash).

        Raises:
            ValueError: If the first proposal method has no
                ``speculative_tokens`` value.
        """
        # Currently we only support one proposal method
        spec_config = config_dict["speculators_config"]
        first_method = spec_config.get("proposal_methods")[0]
        num_lookahead_tokens = first_method.get("speculative_tokens")

        if num_lookahead_tokens is None:
            raise ValueError(
                "Missing 'speculative_tokens' in proposal method. "
                f"Got: {first_method}")

        # Build base vLLM config
        vllm_config = {
            "method": config_dict.get("speculators_model_type"),
            "num_lookahead_tokens": num_lookahead_tokens,
            "target_model": spec_config.get("verifier")["name_or_path"]
        }
        vllm_config.update(config_dict["transformer_layer_config"])
        return vllm_config
|
||||||
Loading…
x
Reference in New Issue
Block a user