[Meta] Llama4 EAGLE Support (#20591)

Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: qizixi <qizixi@meta.com>
This commit is contained in:
zhiweiz 2025-07-15 21:14:15 -07:00 committed by GitHub
parent 1eb2b9c102
commit c11013db8b
6 changed files with 255 additions and 15 deletions
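
For orientation, this is roughly how the new Llama4 EAGLE path can be exercised end to end. The sketch below is not part of the diff; it mirrors the configuration used by the (CI-skipped) llama4_eagle case in the updated correctness test — same draft model, method, num_speculative_tokens, and tensor_parallel_size — while the prompt and sampling settings are purely illustrative.

from vllm import LLM, SamplingParams

# Target/draft pairing and speculative settings taken from the new test parameterization.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,
    max_model_len=2048,
    trust_remote_code=True,
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,
        "max_model_len": 2048,
    },
)

# Illustrative prompt; the test itself compares chat outputs against a non-speculative baseline.
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)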

View File

@@ -84,6 +84,7 @@ def main():
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)

View File

@@ -465,6 +465,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,

View File

@@ -36,6 +36,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM")
if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
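
As the comment in this hunk notes, the initialization test keeps memory in check by shrinking the model through the hf_overrides hook. A minimal sketch of such a hook is below; the helper name and the text_config fallback are assumptions for illustration, not code from this diff.

from transformers import PretrainedConfig

def one_layer_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
    # Illustrative only: point at whichever sub-config holds the decoder stack
    # (Llama4 keeps it under text_config) and cap it at one hidden layer.
    target_cfg = getattr(hf_config, "text_config", hf_config)
    target_cfg.num_hidden_layers = 1
    return hf_config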

View File

@@ -6,8 +6,10 @@ import random
from typing import Any
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
@pytest.fixture
@@ -53,14 +55,6 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"
def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name():
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +71,8 @@ def test_ngram_correctness(
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
model=model_name,
@@ -103,34 +99,50 @@ def test_ngram_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
@pytest.mark.parametrize("model_setup", [
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
model_setup: tuple[str, str, str, int],
):
'''
Compare the outputs of the original LLM and the speculative LLM.
They should be the same when using EAGLE speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
method, model_name, spec_model_name, tp_size = model_setup
ref_llm = LLM(model=model_name, max_model_len=2048)
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
@@ -152,3 +164,5 @@ def test_eagle_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

View File

@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
start_layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = (
vllm_config.speculative_config.draft_model_config.hf_config)
self.validate_and_update_config(start_layer_id, quant_config)
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
Llama4DecoderLayer(
self.config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
self.config.hidden_size,
bias=False)
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
name = name.removeprefix("model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# if PP is disabled, the draft model shares embed_tokens with the target model
if get_pp_group().world_size == 1 and \
"embed_tokens." in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for name in params_dict:
# if PP is disabled, the draft model shares embed_tokens with the target model
if get_pp_group().world_size == 1 and \
"embed_tokens." in name:
continue
assert name in loaded_params, f"{name} is not loaded!"
return loaded_params
def validate_and_update_config(
self,
start_layer_id: int,
quant_config: Optional[QuantizationConfig] = None) -> None:
# yoco and moe are not supported by the draft model yet
assert self.config.yoco_global_kv_layer is None
assert self.config.yoco_local_kv_layer is None
assert len(self.config.moe_layers) == 0
# draft model layer index is increased by start_layer_id,
# so we need to pad relevant configs accordingly
self.config.no_rope_layers = [
0
] * start_layer_id + self.config.no_rope_layers
# currently only TorchAO quantization is supported
if isinstance(quant_config, TorchAOConfig):
def pad_layer_name(layer: str) -> str:
layer_index = extract_layer_index(layer)
return layer.replace(str(layer_index),
str(layer_index + start_layer_id))
quant_config.torchao_config.module_fqn_to_config = {
pad_layer_name(layer): quantization
for layer, quantization in
quant_config.torchao_config.module_fqn_to_config.items()
}
class EagleLlama4ForCausalLM(Llama4ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = (
vllm_config.speculative_config.draft_model_config.hf_config)
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
# draft model quantization config may differ from target model
quant_config = VllmConfig.get_quantization_config(
vllm_config.speculative_config.draft_model_config,
vllm_config.load_config)
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=target_layer_num,
quant_config=quant_config)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
loader = AutoWeightsLoader(
self,
# lm_head is tied with target model (Llama4ForCausalLM)
skip_prefixes=(["lm_head."]),
)
model_weights = {}
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
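
To make the data flow of LlamaModel.forward above concrete: the draft stack does not run on token embeddings alone; it first fuses them with the hidden states handed over by the target Llama4ForCausalLM through the fc projection. A standalone sketch of that step follows (tensor sizes are illustrative, not Scout's real dimensions).

import torch

hidden_size, num_tokens = 1024, 8  # illustrative sizes only
input_embeds = torch.randn(num_tokens, hidden_size)   # embed_tokens(input_ids)
target_hidden = torch.randn(num_tokens, hidden_size)  # hidden states from the target model
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)
fused = fc(torch.cat((input_embeds, target_hidden), dim=-1))
assert fused.shape == (num_tokens, hidden_size)  # then flows through the Llama4DecoderLayer stack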

View File

@@ -244,6 +244,7 @@ _SPECULATIVE_DECODING_MODELS = {
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
"EAGLEModel": ("eagle", "EAGLE"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),