From 5a4b4b3729e1a1594bf56d38b7c8d3f556754634 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Tue, 12 Aug 2025 21:54:52 +0530 Subject: [PATCH] Add: `SupportsEagle3` interface for explicit EAGLE3 support (#22642) Signed-off-by: Rahul Tuli --- .../speculators/test_eagle3.py | 18 ++++++- vllm/model_executor/models/interfaces.py | 53 +++++++++++++++++++ vllm/model_executor/models/llama.py | 4 +- vllm/model_executor/models/qwen3.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 10 +++- 5 files changed, 81 insertions(+), 8 deletions(-) diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index c46ac7a88b751..45ddb2178722a 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -3,12 +3,20 @@ import pytest import torch +from vllm.model_executor.models.interfaces import supports_eagle3 + @pytest.mark.parametrize( "model_path", [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")]) -def test_llama(vllm_runner, example_prompts, model_path): +def test_llama(vllm_runner, example_prompts, model_path, monkeypatch): + # Set environment variable for V1 engine serialization + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + eagle3_supported = vllm_model.apply_model(supports_eagle3) + assert eagle3_supported + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) print(vllm_outputs) @@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path): @pytest.mark.parametrize( "model_path", [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")]) -def test_qwen(vllm_runner, example_prompts, model_path): +def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch): + # Set environment variable for V1 engine serialization + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + eagle3_supported = vllm_model.apply_model(supports_eagle3) + assert eagle3_supported + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) print(vllm_outputs) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 46caf3fce4046..c425488f834b5 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -823,3 +823,56 @@ def supports_v0_only( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: return getattr(model, "supports_v0_only", False) + + +@runtime_checkable +class SupportsEagle3(Protocol): + """The interface required for models that support + EAGLE3 speculative decoding.""" + + supports_eagle3: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports EAGLE3 + speculative decoding. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """ + Set which layers should output auxiliary + hidden states for EAGLE3. + + Args: + layers: Tuple of layer indices that should output auxiliary + hidden states. + """ + ... + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """ + Get the layer indices that should output auxiliary hidden states + for EAGLE3. + + Returns: + Tuple of layer indices for auxiliary hidden state outputs. + """ + ... 
+ + +@overload +def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: + ... + + +@overload +def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: + ... + + +def supports_eagle3( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]: + return isinstance(model, SupportsEagle3) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index bc511d833908e..24cd448d8361f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -463,7 +463,7 @@ class LlamaModel(nn.Module): return loaded_params -class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"] diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 0ad50640bb3bc..2060206633702 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, @@ -261,7 +261,7 @@ class Qwen3Model(Qwen2Model): decoder_layer_type=Qwen3DecoderLayer) -class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ed4d6bcb09d42..2e1cc37b1b761 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import (is_mixture_of_experts, + supports_eagle3, supports_transcription) from vllm.model_executor.models.interfaces_base import ( VllmModelForPooling, is_pooling_model, is_text_generation_model) @@ -1981,8 +1982,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) + if supports_eagle3(self.model): + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) + else: + raise RuntimeError( + "Model does not support EAGLE3 interface but " + "aux_hidden_state_outputs was requested") 
time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds",
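
The helper added in interfaces.py narrows types for both classes and instances: `SupportsEagle3` is a `@runtime_checkable` `Protocol`, so a single `isinstance` check works whether `model` is a model class or a live instance, and the `@overload` pair gives `TypeIs` narrowing in either position. Below is a minimal sketch of how a model opts in and how a call site uses the check; `ToyEagle3Model`, its `_aux_layers` field, and the layer indices `(2, 16, 30)` are illustrative assumptions, not part of this patch.

    from vllm.model_executor.models.interfaces import (SupportsEagle3,
                                                       supports_eagle3)


    class ToyEagle3Model(SupportsEagle3):
        """Hypothetical model class, used only to illustrate the interface."""

        def __init__(self) -> None:
            self._aux_layers: tuple[int, ...] = ()

        def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
            # Record which decoder layers should emit auxiliary hidden
            # states for the EAGLE3 drafter.
            self._aux_layers = layers

        def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
            # Illustrative default; a real model picks indices suited to
            # its depth.
            return (2, 16, 30)


    model = ToyEagle3Model()
    if supports_eagle3(model):
        # TypeIs narrows `model` to SupportsEagle3 here, mirroring the
        # guard added to gpu_model_runner.py above.
        model.set_aux_hidden_state_layers(
            model.get_eagle3_aux_hidden_state_layers())

Unlike `supports_v0_only`, which reads a class flag through `getattr`, this check verifies that the protocol's members are actually present, so inheriting `SupportsEagle3` (as `LlamaForCausalLM` and `Qwen3ForCausalLM` now do) is the straightforward way to satisfy it. The tests exercise the same check via `vllm_model.apply_model(supports_eagle3)`; because `apply_model` ships the callable to the engine over collective RPC, it has to be pickled, which the V1 engine only permits with `VLLM_ALLOW_INSECURE_SERIALIZATION=1`, hence the `monkeypatch.setenv` in both tests.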