Add Eagle and Eagle3 support to Transformers modeling backend (#30340)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-12-11 17:02:10 +00:00 committed by GitHub
parent aa3c250c48
commit 8781cd6b88
2 changed files with 94 additions and 8 deletions

@@ -280,9 +280,20 @@ def test_speculators_model_integration(
@pytest.mark.parametrize(
["model_setup", "mm_enabled", "enable_chunked_prefill"],
["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"auto",
),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"transformers",
),
pytest.param(
(
"eagle3",
@@ -292,6 +303,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
),
@@ -305,6 +317,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32"
),
@@ -318,6 +331,7 @@ def test_speculators_model_integration(
),
False,
True,
"auto",
marks=large_gpu_mark(min_gb=40),
), # works on 4x H100
(
@@ -329,6 +343,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
),
pytest.param(
(
@@ -339,6 +354,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
@@ -350,6 +366,7 @@ def test_speculators_model_integration(
),
True,
True,
"auto",
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
@@ -361,10 +378,12 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
),
],
ids=[
"qwen3_eagle3",
"qwen3_eagle3-transformers",
"qwen3_vl_eagle3",
"qwen2_5_vl_eagle3",
"llama3_eagle",
@@ -381,6 +400,7 @@ def test_eagle_correctness(
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
enable_chunked_prefill: bool,
model_impl: str,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
@@ -389,6 +409,17 @@ def test_eagle_correctness(
"TREE_ATTN is flaky in the test disable for now until it can be "
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
)
if model_impl == "transformers":
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
if installed < required:
pytest.skip(
"Eagle3 with the Transformers modeling backend requires "
f"transformers>={required}, but got {installed}"
)
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
@@ -448,6 +479,7 @@ def test_eagle_correctness(
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
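
The first file extends the Eagle3 correctness test with a model_impl dimension, adding the "qwen3_eagle3-transformers" case and a skip when the installed transformers is older than 5.0.0.dev. A minimal offline sketch of the combination that case exercises (illustrative only, not part of this commit: it assumes vLLM's LLM/SamplingParams API and a speculative_config layout with "method", "model", and "num_speculative_tokens" keys; the prompt and token counts are made up):

from vllm import LLM, SamplingParams

# Target model is served through the Transformers modeling backend while the
# Eagle3 draft model comes from the speculative config.
llm = LLM(
    model="Qwen/Qwen3-8B",
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-8B_eagle3",
        "num_speculative_tokens": 3,  # illustrative value, not taken from the test
    },
    model_impl="transformers",
)
outputs = llm.generate(["The future of AI is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)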


@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import (
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
SupportsQuant,
@@ -92,7 +94,15 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
class Base(
nn.Module,
VllmModel,
SupportsQuant,
SupportsLoRA,
SupportsPP,
SupportsEagle,
SupportsEagle3,
):
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights`
# Attrs for weight loading (see self.load_weights)
self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes.
"""
"""Ignore unexpected weights whose qualname starts with these prefixes."""
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""
# Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
self._target_class: type[nn.Module] = nn.Module
"""Target class for Eagle3 aux hidden state recording."""
self._layer_names: dict[int, str] = {}
"""Mapping from layer index to layer name for Eagle3."""
self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
"""Kwargs to pass to model forward for Eagle3 aux hidden states."""
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
# Populate Eagle3 attrs
if (
isinstance(module, nn.ModuleList)
and len(module) == self.text_config.num_hidden_layers
):
self._target_class = type(child_module)
layer_name = qual_name.removeprefix("model.")
self._layer_names[int(child_name)] = layer_name
# Replace modules as needed
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None)
@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
else:
position_ids = positions[None, ...]
hidden_states = self.model(
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=position_ids,
attention_instances=self.attention_instances,
return_dict=False,
**self._output_aux_hidden_states_kwargs,
**kwargs,
)[0][0, ...] # we remove batch dimension for now
)
# We must remove the batch dimension from these outputs
hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(
@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
f"Transformers modeling backend requires transformers>={required} "
f"for {feature}, but got {installed}"
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.check_version("5.0.0.dev0", "Eagle3 support")
from transformers.utils.generic import OutputRecorder
# The default value in PreTrainedModel is None
if self.model._can_record_outputs is None:
self.model._can_record_outputs = {}
target_class = self._target_class
for layer in layers:
# layer - 1 because we want the input to the layer
layer_name = self._layer_names[layer - 1]
layer_key = f"aux_hidden_state_{layer}"
aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name)
self.model._can_record_outputs[layer_key] = aux_hidden_state_i
self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = self.text_config.num_hidden_layers
return (2, num_layers // 2, num_layers - 3)
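
For reference, a short worked example of the defaults added in the second file (illustrative only; the 36-layer count is an assumption for a Qwen3-8B-sized text config and is not taken from this diff):

# Default Eagle3 aux hidden state layers for text_config.num_hidden_layers == 36
num_layers = 36
layers = (2, num_layers // 2, num_layers - 3)  # -> (2, 18, 33)
# set_aux_hidden_state_layers(layers) then registers one OutputRecorder per index,
# keyed "aux_hidden_state_<layer>" and pointed at self._layer_names[layer - 1], so
# the recorded tensor is the input of that layer; forward() unpacks the extra
# outputs and returns (hidden_states, aux_hidden_states) on the last pipeline rank.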