[Speculative Decoding] Add speculators config support (#21345)
parent 87c94bc879
commit dfbc1f8880
tests/speculative_decoding/speculators/test_eagle3.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+
+@pytest.mark.parametrize(
+    "model_path",
+    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
+     ("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
+def test_llama(vllm_runner, example_prompts, model_path):
+    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                  max_tokens=20)
+        print(vllm_outputs)
+        assert vllm_outputs
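What this test exercises can also be driven through vLLM's offline API. A minimal sketch (assuming the checkpoint above is reachable on the Hugging Face Hub and the hardware supports bfloat16): the speculators-format checkpoint doubles as the entry point, and vLLM resolves the verifier model and speculative config from it.

from vllm import LLM, SamplingParams

# Point --model / model= at the speculators checkpoint itself; vLLM
# re-points serving at the verifier (target) model automatically.
llm = LLM(model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
          dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)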
@@ -39,8 +39,8 @@ from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    try_get_generation_config, try_get_safetensors_metadata,
-    try_get_tokenizer_config, uses_mrope)
+    maybe_override_with_speculators_target_model, try_get_generation_config,
+    try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 # yapf conflicts with isort for this block
@@ -535,6 +535,15 @@ class ModelConfig:
                 "affect the random state of the Python process that "
                 "launched vLLM.", self.seed)
 
+        if self.runner != "draft":
+            # If we're not running the draft model, check for a speculators
+            # config; if present, set model / tokenizer to the target model.
+            self.model, self.tokenizer = maybe_override_with_speculators_target_model(  # noqa: E501
+                model=self.model,
+                tokenizer=self.tokenizer,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code)
+
         # Keep set served_model_name before maybe_model_redirect(self.model)
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
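For context, a hypothetical skeleton of the config.json such a checkpoint carries; all field values here are illustrative, not taken from the commit. The presence of the speculators_config key is what triggers the override, and verifier.name_or_path names the target model that vLLM actually serves.

# Illustrative speculators-format config.json, expressed as a Python dict.
speculators_checkpoint_config = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 3}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
    },
    "transformer_layer_config": {"hidden_size": 4096},
}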
@@ -606,8 +615,8 @@ class ModelConfig:
             self.config_format,
             hf_overrides_kw=hf_overrides_kw,
             hf_overrides_fn=hf_overrides_fn)
-        self.hf_config = hf_config
 
+        self.hf_config = hf_config
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(self.hf_text_config,
                                             "attention_chunk_size", None)
@@ -2980,10 +2989,13 @@ class SpeculativeConfig:
                         "Chunked prefill and EAGLE are not compatible "
                         "when using V0.")
 
+                from vllm.transformers_utils.configs import (
+                    SpeculatorsConfig)
                 from vllm.transformers_utils.configs.eagle import (
                     EAGLEConfig)
+
                 if isinstance(self.draft_model_config.hf_config,
-                              EAGLEConfig):
+                              (EAGLEConfig, SpeculatorsConfig)):
                     pass
                 else:
                     eagle_config = EAGLEConfig(
@@ -978,8 +978,28 @@ class EngineArgs:
         provided as a JSON string input via CLI arguments or directly as a
         dictionary from the engine.
         """
+
+        from vllm.transformers_utils.config import get_config
+        from vllm.transformers_utils.configs.speculators.base import (
+            SpeculatorsConfig)
+
         if self.speculative_config is None:
-            return None
+            hf_config = get_config(self.hf_config_path or self.model,
+                                   self.trust_remote_code, self.revision,
+                                   self.code_revision, self.config_format)
+
+            # If loading a SpeculatorsConfig, pull the speculative_config
+            # details directly from the config;
+            # no user input is required or expected.
+            if isinstance(hf_config, SpeculatorsConfig):
+                # Create one here, since the user did not provide one.
+                self.speculative_config = {}
+                self.speculative_config[
+                    "num_speculative_tokens"] = hf_config.num_lookahead_tokens
+                self.speculative_config["model"] = self.model
+                self.speculative_config["method"] = hf_config.method
+            else:
+                return None
 
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
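When the served model is a speculators-format checkpoint, the block above synthesizes the speculative_config with no user input. An illustrative sketch of the resulting dict (values assumed, not taken from the commit):

speculative_config = {
    "num_speculative_tokens": 3,  # from hf_config.num_lookahead_tokens
    "model": "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    "method": "eagle3",           # from hf_config.method
}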
@@ -51,6 +51,25 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        if getattr(config, "norm_before_residual", False):
+            self._residual_norm = self._norm_before_residual
+        else:
+            self._residual_norm = self._norm_after_residual
+
+    def _norm_before_residual(
+            self,
+            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.hidden_norm(hidden_states)
+        residual = hidden_states
+        return hidden_states, residual
+
+    def _norm_after_residual(
+            self,
+            hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.hidden_norm(hidden_states)
+        return hidden_states, residual
+
     def forward(
         self,
         positions: torch.Tensor,
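The two variants differ only in whether the residual branch receives the hidden states before or after hidden_norm is applied. A minimal standalone sketch of the contrast (assumes torch >= 2.4 for nn.RMSNorm; shapes are illustrative):

import torch
from torch import nn

norm = nn.RMSNorm(8)
h = torch.randn(2, 8)

# norm_before_residual=True: the residual branch carries normalized states
h1 = norm(h)
residual_before = h1

# norm_before_residual=False (default): the residual carries the raw states
residual_after = h
h2 = norm(h)

assert torch.equal(h1, h2)                               # same normalized output
assert not torch.equal(residual_before, residual_after)  # residuals differ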
@@ -59,9 +78,10 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
 
-        residual = hidden_states
         embeds = self.input_layernorm(embeds)
-        hidden_states = self.hidden_norm(hidden_states)
+
+        hidden_states, residual = self._residual_norm(
+            hidden_states=hidden_states)
 
         hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
@@ -102,7 +122,7 @@ class LlamaModel(nn.Module):
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
-                self.config,
+                config=self.config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])
@@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
                                              MllamaConfig, MLPSpeculatorConfig,
                                              Nemotron_Nano_VL_Config,
                                              NemotronConfig, NVLM_D_Config,
-                                             RWConfig, Step3TextConfig,
-                                             Step3VLConfig, UltravoxConfig)
+                                             RWConfig, SpeculatorsConfig,
+                                             Step3TextConfig, Step3VLConfig,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.configs.mistral import adapt_config_dict
 from vllm.transformers_utils.utils import check_gguf_file
@@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     "mlp_speculator": MLPSpeculatorConfig,
     "medusa": MedusaConfig,
     "eagle": EAGLEConfig,
+    "speculators": SpeculatorsConfig,
     "nemotron": NemotronConfig,
     "NVLM_D": NVLM_D_Config,
     "ultravox": UltravoxConfig,
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     return config
 
 
+def maybe_override_with_speculators_target_model(
+        model: str,
+        tokenizer: str,
+        trust_remote_code: bool,
+        revision: Optional[str] = None) -> tuple[str, str]:
+    """
+    If running a speculators config, override running model with target model
+    """
+    config_dict, _ = PretrainedConfig.get_config_dict(
+        model,
+        revision=revision,
+        trust_remote_code=trust_remote_code,
+        token=_get_hf_token(),
+    )
+    spec_config = config_dict.get("speculators_config")
+    # Return the target model
+    if spec_config is not None:
+        model = tokenizer = spec_config["verifier"]["name_or_path"]
+    return model, tokenizer
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
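A hedged usage sketch of the new helper (assuming network access to the Hub): for a speculators-format checkpoint both return values come back re-pointed at the verifier, while any other model id passes through unchanged. The checkpoint name is the one used in the tests above.

from vllm.transformers_utils.config import (
    maybe_override_with_speculators_target_model)

model, tokenizer = maybe_override_with_speculators_target_model(
    model="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    tokenizer="nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717",
    trust_remote_code=False)
print(model)  # the verifier's name_or_path from speculators_config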
@@ -345,9 +368,12 @@ def get_config(
             token=_get_hf_token(),
             **kwargs,
         )
 
         # Use custom model class if it's in our registry
         model_type = config_dict.get("model_type")
+        if model_type is None:
+            model_type = "speculators" if config_dict.get(
+                "speculators_config") is not None else model_type
 
         if model_type in _CONFIG_REGISTRY:
             config_class = _CONFIG_REGISTRY[model_type]
             config = config_class.from_pretrained(
@@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
 from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
                                                       Step3VisionEncoderConfig,
                                                       Step3VLConfig)
@@ -44,6 +45,7 @@ __all__ = [
     "NemotronHConfig",
     "Nemotron_Nano_VL_Config",
     "NVLM_D_Config",
+    "SpeculatorsConfig",
     "UltravoxConfig",
     "Step3VLConfig",
     "Step3VisionEncoderConfig",
vllm/transformers_utils/configs/speculators/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/transformers_utils/configs/speculators/algos.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+SUPPORTED_SPECULATORS_TYPES = {}
+
+
+def register_speculator(name):
+
+    def decorator(fn):
+        SUPPORTED_SPECULATORS_TYPES[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_speculator("eagle3")
+def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
+    """
+    Apply Eagle-3 specific configuration transformations.
+
+    Eagle-3 specific fields:
+    - draft_vocab_size: Size of the draft model's vocabulary
+    - target_hidden_size: Hidden size of the target model
+    - norm_before_residual: Whether to apply norm before residual connection
+    """
+
+    vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
+    if config_dict.get("target_hidden_size") is not None:
+        vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
+    vllm_config["norm_before_residual"] = config_dict.get(
+        "norm_before_residual", True)
+    vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
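A hypothetical sketch of how a future algorithm would plug into this registry; the "hass" name and updater body are illustrative only and not part of this commit.

from vllm.transformers_utils.configs.speculators.algos import (
    SUPPORTED_SPECULATORS_TYPES, register_speculator)

@register_speculator("hass")
def update_hass(config_dict: dict, vllm_config: dict) -> None:
    # Map algorithm-specific fields into the flat vLLM config here.
    vllm_config["architectures"] = ["HassLlamaForCausalLM"]

assert "hass" in SUPPORTED_SPECULATORS_TYPES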
vllm/transformers_utils/configs/speculators/base.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+from typing import Any, Union
+
+from transformers import PretrainedConfig
+
+from vllm.transformers_utils.configs.speculators.algos import (
+    SUPPORTED_SPECULATORS_TYPES)
+
+__all__ = ["SpeculatorsConfig"]
+
+
+class SpeculatorsConfig(PretrainedConfig):
+    model_type = "speculators"
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        **kwargs,
+    ) -> "SpeculatorsConfig":
+        """Load speculators Eagle config and convert to vLLM format."""
+        config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
+                                             **kwargs)
+
+        speculators_model_type = config_dict.get("speculators_model_type")
+        if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
+            raise ValueError(
+                f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
+                "Please ensure you're loading a speculators-format model.")
+
+        # validate fields
+        # TODO: @dsikka - use speculators pydantic model to validate
+        cls.validate_speculators_config(config_dict=config_dict)
+        # Convert from speculators config -> format that can be ingested by vLLM
+        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
+        # Apply anything specific to the supported algorithm
+        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
+        algo_updater(config_dict=config_dict, vllm_config=vllm_config)
+        return cls(**vllm_config)
+
+    @classmethod
+    def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
+        try:
+            spec_config = config_dict["speculators_config"]
+            methods = spec_config["proposal_methods"]
+            first_method = methods[0]
+            _ = first_method["speculative_tokens"]
+            _ = spec_config["verifier"]["name_or_path"]
+            _ = config_dict["speculators_model_type"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError("Invalid speculators config structure") from e
+
+        if "transformer_layer_config" not in config_dict:
+            raise ValueError("Must provide transformer_layer_config")
+
+        if not isinstance(config_dict["transformer_layer_config"], dict):
+            raise TypeError(
+                "'transformer_layer_config' must be a dictionary if provided")
+
+    @classmethod
+    def convert_speculators_to_vllm(
+            cls, config_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        Convert speculators config format to vLLM format.
+
+        This method handles the translation of field names and structure
+        between speculators and vLLM formats.
+
+        Returns:
+            Dictionary with vLLM-compatible configuration
+        """
+        # Currently we only support one proposal method
+        spec_config = config_dict["speculators_config"]
+        first_method = spec_config.get("proposal_methods")[0]
+        num_lookahead_tokens = first_method.get("speculative_tokens")
+
+        if num_lookahead_tokens is None:
+            raise ValueError(
+                "Missing 'speculative_tokens' in proposal method. "
+                f"Got: {first_method}")
+
+        # Build base vLLM config
+        vllm_config = {
+            "method": config_dict.get("speculators_model_type"),
+            "num_lookahead_tokens": num_lookahead_tokens,
+            "target_model": spec_config.get("verifier")["name_or_path"]
+        }
+        vllm_config.update(config_dict["transformer_layer_config"])
+        return vllm_config
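An illustrative round-trip through convert_speculators_to_vllm, with assumed field values shaped after the keys the validator checks above: the nested speculators layout becomes the flat dict vLLM ingests, with transformer_layer_config merged in at the top level.

from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig

cfg = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 3}],
        "verifier": {"name_or_path": "meta-llama/Llama-3.1-8B-Instruct"},
    },
    "transformer_layer_config": {"hidden_size": 4096},
}
print(SpeculatorsConfig.convert_speculators_to_vllm(config_dict=cfg))
# {'method': 'eagle3', 'num_lookahead_tokens': 3,
#  'target_model': 'meta-llama/Llama-3.1-8B-Instruct', 'hidden_size': 4096}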