Add: SupportsEagle3 interface for explicit EAGLE3 support (#22642)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
Rahul Tuli 2025-08-12 21:54:52 +05:30 committed by GitHub
parent e5d3d63c42
commit 5a4b4b3729
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 81 additions and 8 deletions

View File

@@ -3,12 +3,20 @@
import pytest
import torch
from vllm.model_executor.models.interfaces import supports_eagle3
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path):
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
@@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
@pytest.mark.parametrize(
"model_path",
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
def test_qwen(vllm_runner, example_prompts, model_path):
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)

View File

@@ -823,3 +823,56 @@ def supports_v0_only(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
return getattr(model, "supports_v0_only", False)
@runtime_checkable
class SupportsEagle3(Protocol):
    """The interface required for models that support
    EAGLE3 speculative decoding.

    Models opt in by inheriting this protocol (or simply providing the
    attribute and the two methods below); @runtime_checkable lets callers
    detect support structurally via isinstance() — see supports_eagle3().
    """

    supports_eagle3: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports EAGLE3
    speculative decoding.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """
        Set which layers should output auxiliary
        hidden states for EAGLE3.

        Args:
            layers: Tuple of layer indices that should output auxiliary
                hidden states.
        """
        ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """
        Get the layer indices that should output auxiliary hidden states
        for EAGLE3.

        Returns:
            Tuple of layer indices for auxiliary hidden state outputs.
        """
        ...
@overload
def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]:
    ...


@overload
def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]:
    ...


def supports_eagle3(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
    """Return whether *model* implements the EAGLE3 interface.

    Accepts either a model class or a model instance; because
    SupportsEagle3 is @runtime_checkable, isinstance() performs a
    structural check for the flag and the two required methods.

    Args:
        model: A model class or instance to test.

    Returns:
        True if the model structurally satisfies SupportsEagle3,
        narrowing the type for static checkers via TypeIs.
    """
    return isinstance(model, SupportsEagle3)

View File

@@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@@ -463,7 +463,7 @@ class LlamaModel(nn.Module):
return loaded_params
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]

View File

@@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -261,7 +261,7 @@ class Qwen3Model(Qwen2Model):
decoder_layer_type=Qwen3DecoderLayer)
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",

View File

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
supports_eagle3,
supports_transcription)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -1981,8 +1982,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
if supports_eagle3(self.model):
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
raise RuntimeError(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested")
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",