From 8781cd6b88ad264a01886a05e698b5e036fb4eb9 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 11 Dec 2025 17:02:10 +0000
Subject: [PATCH] Add Eagle and Eagle3 support to Transformers modeling
 backend (#30340)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 tests/v1/e2e/test_spec_decode.py | 36 +++++++++-
 .../models/transformers/base.py  | 66 +++++++++++++++++--
 2 files changed, 94 insertions(+), 8 deletions(-)

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 8c904a8cddac4..c8587659d6580 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -280,9 +280,20 @@ def test_speculators_model_integration(
 
 
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled", "enable_chunked_prefill"],
+    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "auto",
+        ),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "transformers",
+        ),
         pytest.param(
             (
                 "eagle3",
@@ -292,6 +303,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=pytest.mark.skip(
                 reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
             ),
@@ -305,6 +317,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
            ),
@@ -318,6 +331,7 @@ def test_speculators_model_integration(
             ),
             False,
             True,
+            "auto",
             marks=large_gpu_mark(min_gb=40),
         ),  # works on 4x H100
         (
@@ -329,6 +343,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
         ),
         pytest.param(
             (
@@ -339,6 +354,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -350,6 +366,7 @@ def test_speculators_model_integration(
             ),
             True,
             True,
+            "auto",
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -361,10 +378,12 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
         ),
     ],
     ids=[
         "qwen3_eagle3",
+        "qwen3_eagle3-transformers",
         "qwen3_vl_eagle3",
         "qwen2_5_vl_eagle3",
         "llama3_eagle",
@@ -381,6 +400,7 @@ def test_eagle_correctness(
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
     enable_chunked_prefill: bool,
+    model_impl: str,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -389,6 +409,17 @@ def test_eagle_correctness(
             "TREE_ATTN is flaky in the test disable for now until it can be "
             "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
         )
+    if model_impl == "transformers":
+        import transformers
+        from packaging.version import Version
+
+        installed = Version(transformers.__version__)
+        required = Version("5.0.0.dev")
+        if installed < required:
+            pytest.skip(
+                "Eagle3 with the Transformers modeling backend requires "
+                f"transformers>={required}, but got {installed}"
+            )
 
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -448,6 +479,7 @@ def test_eagle_correctness(
         max_model_len=max_model_len,
         max_num_batched_tokens=max_num_batched_tokens,
         enable_chunked_prefill=enable_chunked_prefill,
+        model_impl=model_impl,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py
index f3ebc6da8e302..45e746ac2d356 100644
--- a/vllm/model_executor/models/transformers/base.py
+++ b/vllm/model_executor/models/transformers/base.py
@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.models.interfaces import (
+    SupportsEagle,
+    SupportsEagle3,
     SupportsLoRA,
     SupportsPP,
     SupportsQuant,
@@ -92,7 +94,15 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
-class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
+class Base(
+    nn.Module,
+    VllmModel,
+    SupportsQuant,
+    SupportsLoRA,
+    SupportsPP,
+    SupportsEagle,
+    SupportsEagle3,
+):
     embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
         self.pp_group = get_pp_group()
         self.tp_group = get_tp_group()
 
-        # Weights to skip in `self.load_weights`
+        # Attrs for weight loading (see self.load_weights)
         self.skip_prefixes: list[str] = []
         """Skip loading weights whose qualname starts with these prefixes."""
         self.skip_substrs: list[str] = []
         """Skip loading weights whose qualname contains these substrings."""
         self.ignore_unexpected_prefixes: list[str] = []
-        """Ignore unexpected weights whose qualname starts with these prefixes.
-        """
+        """Ignore unexpected weights whose qualname starts with these prefixes."""
         self.ignore_unexpected_suffixes: list[str] = []
         """Ignore unexpected weights whose qualname ends with these suffixes."""
 
+        # Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
+        self._target_class: type[nn.Module] = nn.Module
+        """Target class for Eagle3 aux hidden state recording."""
+        self._layer_names: dict[int, str] = {}
+        """Mapping from layer index to layer name for Eagle3."""
+        self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
+        """Kwargs to pass to model forward for Eagle3 aux hidden states."""
+
         if self.quant_config:
             quant_method_name = self.quant_config.get_name()
             # Check for unsupported quantization methods.
@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
             for child_name, child_module in module.named_children():
                 new_module = child_module
                 qual_name = maybe_prefix(prefix, child_name)
+                # Populate Eagle3 attrs
+                if (
+                    isinstance(module, nn.ModuleList)
+                    and len(module) == self.text_config.num_hidden_layers
+                ):
+                    self._target_class = type(child_module)
+                    layer_name = qual_name.removeprefix("model.")
+                    self._layer_names[int(child_name)] = layer_name
+                # Replace modules as needed
                 if isinstance(child_module, nn.Linear):
                     generator = (p for p in tp_plan if re.match(p, qual_name))
                     pattern = next(generator, None)
@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
         else:
             position_ids = positions[None, ...]
 
-        hidden_states = self.model(
+        outputs = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             use_cache=False,
             position_ids=position_ids,
             attention_instances=self.attention_instances,
             return_dict=False,
+            **self._output_aux_hidden_states_kwargs,
             **kwargs,
-        )[0][0, ...]  # we remove batch dimension for now
+        )
+        # We must remove the batch dimension from these outputs
+        hidden_states = outputs[0][0, ...]
+        if self._output_aux_hidden_states_kwargs:
+            aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
 
         if not self.pp_group.is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
 
+        if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def load_weights(
@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
                 f"Transformers modeling backend requires transformers>={required} "
                 f"for {feature}, but got {installed}"
             )
+
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.check_version("5.0.0.dev0", "Eagle3 support")
+        from transformers.utils.generic import OutputRecorder
+
+        # The default value in PreTrainedModel is None
+        if self.model._can_record_outputs is None:
+            self.model._can_record_outputs = {}
+
+        target_class = self._target_class
+        for layer in layers:
+            # layer - 1 because we want the input to the layer
+            layer_name = self._layer_names[layer - 1]
+            layer_key = f"aux_hidden_state_{layer}"
+            aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name)
+            self.model._can_record_outputs[layer_key] = aux_hidden_state_i
+            self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = self.text_config.num_hidden_layers
+        return (2, num_layers // 2, num_layers - 3)
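
Usage sketch (illustrative, not part of the patch): the test above exercises the new path by building the speculative LLM with model_impl="transformers", which the added version check gates on transformers>=5.0.0.dev. A minimal standalone equivalent might look like the following; the target and drafter checkpoints are the ones used in the qwen3_eagle3-transformers test case, while the speculative_config keys and the num_speculative_tokens value shown here are assumptions about vLLM's dict-based speculative config rather than a copy of the test's exact setup.

from vllm import LLM, SamplingParams

# Hypothetical minimal setup mirroring the qwen3_eagle3-transformers test case.
llm = LLM(
    model="Qwen/Qwen3-8B",
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-8B_eagle3",
        "num_speculative_tokens": 3,  # assumed value, not taken from the test
    },
    model_impl="transformers",  # route the target model through the Transformers backend
    max_model_len=2048,
)

outputs = llm.chat(
    [[{"role": "user", "content": "Explain Eagle3 speculative decoding in one sentence."}]],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)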