Add Eagle and Eagle3 support to Transformers modeling backend (#30340)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Parent: aa3c250c48
Commit: 8781cd6b88
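This change lets Eagle and Eagle3 draft models run against a target model loaded through the Transformers modeling backend (model_impl="transformers"). A minimal offline sketch of the configuration the new tests exercise, assuming vLLM's existing speculative_config interface and the Qwen3 target/draft pair used in the test below; the number of speculative tokens is illustrative, not taken from this commit:

from vllm import LLM, SamplingParams

# Hypothetical usage sketch: target model served via the Transformers
# modeling backend, with an Eagle3 draft model for speculative decoding.
llm = LLM(
    model="Qwen/Qwen3-8B",
    model_impl="transformers",  # force the Transformers modeling backend
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-8B_eagle3",
        "num_speculative_tokens": 3,  # illustrative value
    },
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)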
@@ -280,9 +280,20 @@ def test_speculators_model_integration(
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled", "enable_chunked_prefill"],
+    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "auto",
+        ),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "transformers",
+        ),
         pytest.param(
             (
                 "eagle3",
@@ -292,6 +303,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=pytest.mark.skip(
                 reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
             ),
@@ -305,6 +317,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a a multiple of 32"
             ),
@@ -318,6 +331,7 @@ def test_speculators_model_integration(
             ),
             False,
             True,
+            "auto",
             marks=large_gpu_mark(min_gb=40),
         ),  # works on 4x H100
         (
@@ -329,6 +343,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
         ),
         pytest.param(
             (
@@ -339,6 +354,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         pytest.param(
@@ -350,6 +366,7 @@ def test_speculators_model_integration(
             ),
             True,
             True,
+            "auto",
             marks=large_gpu_mark(min_gb=80),
         ),  # works on 4x H100
         (
@@ -361,10 +378,12 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
         ),
     ],
     ids=[
         "qwen3_eagle3",
+        "qwen3_eagle3-transformers",
         "qwen3_vl_eagle3",
         "qwen2_5_vl_eagle3",
         "llama3_eagle",
@@ -381,6 +400,7 @@ def test_eagle_correctness(
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
     enable_chunked_prefill: bool,
+    model_impl: str,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -389,6 +409,17 @@ def test_eagle_correctness(
             "TREE_ATTN is flaky in the test disable for now until it can be "
             "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
         )
+    if model_impl == "transformers":
+        import transformers
+        from packaging.version import Version
+
+        installed = Version(transformers.__version__)
+        required = Version("5.0.0.dev")
+        if installed < required:
+            pytest.skip(
+                "Eagle3 with the Transformers modeling backend requires "
+                f"transformers>={required}, but got {installed}"
+            )
 
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -448,6 +479,7 @@ def test_eagle_correctness(
         max_model_len=max_model_len,
         max_num_batched_tokens=max_num_batched_tokens,
         enable_chunked_prefill=enable_chunked_prefill,
+        model_impl=model_impl,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0

@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.models.interfaces import (
+    SupportsEagle,
+    SupportsEagle3,
     SupportsLoRA,
     SupportsPP,
     SupportsQuant,
@@ -92,7 +94,15 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
-class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
+class Base(
+    nn.Module,
+    VllmModel,
+    SupportsQuant,
+    SupportsLoRA,
+    SupportsPP,
+    SupportsEagle,
+    SupportsEagle3,
+):
     embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
         self.pp_group = get_pp_group()
         self.tp_group = get_tp_group()
 
-        # Weights to skip in `self.load_weights`
+        # Attrs for weight loading (see self.load_weights)
         self.skip_prefixes: list[str] = []
         """Skip loading weights whose qualname starts with these prefixes."""
         self.skip_substrs: list[str] = []
         """Skip loading weights whose qualname contains these substrings."""
         self.ignore_unexpected_prefixes: list[str] = []
-        """Ignore unexpected weights whose qualname starts with these prefixes.
-        """
+        """Ignore unexpected weights whose qualname starts with these prefixes."""
         self.ignore_unexpected_suffixes: list[str] = []
         """Ignore unexpected weights whose qualname ends with these suffixes."""
 
+        # Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
+        self._target_class: type[nn.Module] = nn.Module
+        """Target class for Eagle3 aux hidden state recording."""
+        self._layer_names: dict[int, str] = {}
+        """Mapping from layer index to layer name for Eagle3."""
+        self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
+        """Kwargs to pass to model forward for Eagle3 aux hidden states."""
+
         if self.quant_config:
             quant_method_name = self.quant_config.get_name()
             # Check for unsupported quantization methods.
@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
         for child_name, child_module in module.named_children():
             new_module = child_module
             qual_name = maybe_prefix(prefix, child_name)
+            # Populate Eagle3 attrs
+            if (
+                isinstance(module, nn.ModuleList)
+                and len(module) == self.text_config.num_hidden_layers
+            ):
+                self._target_class = type(child_module)
+                layer_name = qual_name.removeprefix("model.")
+                self._layer_names[int(child_name)] = layer_name
+            # Replace modules as needed
             if isinstance(child_module, nn.Linear):
                 generator = (p for p in tp_plan if re.match(p, qual_name))
                 pattern = next(generator, None)
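The new bookkeeping above is what the Eagle3 hooks rely on later: while recursing over submodules, the backend records the decoder-layer class and each layer's qualified name as soon as it reaches the nn.ModuleList whose length matches num_hidden_layers. A small hedged illustration of the resulting attributes, assuming the common "model.layers.<i>" naming; the layer class name is just an example:

# Hypothetical values after visiting a 2-layer decoder stack named "model.layers"
qual_name = "model.layers.0"                    # assumed qualified name of layer 0
layer_name = qual_name.removeprefix("model.")   # -> "layers.0"
# self._target_class -> the decoder layer class, e.g. Qwen3DecoderLayer
# self._layer_names  -> {0: "layers.0", 1: "layers.1"}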
@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
         else:
             position_ids = positions[None, ...]
 
-        hidden_states = self.model(
+        outputs = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             use_cache=False,
             position_ids=position_ids,
             attention_instances=self.attention_instances,
             return_dict=False,
+            **self._output_aux_hidden_states_kwargs,
             **kwargs,
-        )[0][0, ...]  # we remove batch dimension for now
+        )
+        # We must remove the batch dimension from these outputs
+        hidden_states = outputs[0][0, ...]
+        if self._output_aux_hidden_states_kwargs:
+            aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
 
         if not self.pp_group.is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
 
+        if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def load_weights(
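The **self._output_aux_hidden_states_kwargs expansion above is empty until Eagle3 is active; once set_aux_hidden_state_layers (added in the next hunk) has run, it turns into per-layer output flags, and the extra recorded tensors show up after the main output, which is why forward() reads outputs[1:]. A sketch of what that dict looks like, assuming layers (2, 18, 33) were requested:

# Hypothetical contents after set_aux_hidden_state_layers((2, 18, 33)):
output_aux_hidden_states_kwargs = {
    "output_aux_hidden_state_2": True,
    "output_aux_hidden_state_18": True,
    "output_aux_hidden_state_33": True,
}
# With these flags set, outputs[1:] holds one recorded entry per flag,
# and forward() strips the batch dimension from each of them.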
@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
                 f"Transformers modeling backend requires transformers>={required} "
                 f"for {feature}, but got {installed}"
             )
+
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.check_version("5.0.0.dev0", "Eagle3 support")
+        from transformers.utils.generic import OutputRecorder
+
+        # The default value in PreTrainedModel is None
+        if self.model._can_record_outputs is None:
+            self.model._can_record_outputs = {}
+
+        target_class = self._target_class
+        for layer in layers:
+            # layer - 1 because we want the input to the layer
+            layer_name = self._layer_names[layer - 1]
+            layer_key = f"aux_hidden_state_{layer}"
+            aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name)
+            self.model._can_record_outputs[layer_key] = aux_hidden_state_i
+            self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = self.text_config.num_hidden_layers
+        return (2, num_layers // 2, num_layers - 3)
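The default layer selection matches the heuristic vLLM already uses for Eagle3 on its native model implementations: one early, one middle, and one late decoder layer. A quick worked example, assuming a text config with 36 hidden layers (e.g. what Qwen3-8B reports):

# Hypothetical default aux-hidden-state layers for an assumed 36-layer model
num_hidden_layers = 36
print((2, num_hidden_layers // 2, num_hidden_layers - 3))  # -> (2, 18, 33)
# set_aux_hidden_state_layers((2, 18, 33)) then registers OutputRecorders under
# "aux_hidden_state_2", "aux_hidden_state_18" and "aux_hidden_state_33", each
# attached to the preceding layer's output, i.e. the input to the requested layer.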