mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 14:35:27 +08:00
[Meta] Llama4 EAGLE Support (#20591)
Signed-off-by: qizixi <qizixi@meta.com> Co-authored-by: qizixi <qizixi@meta.com>
This commit is contained in:
parent
1eb2b9c102
commit
c11013db8b
@ -84,6 +84,7 @@ def main():
|
||||
gpu_memory_utilization=0.8,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=16384,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
|
||||
@ -465,6 +465,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
||||
"EagleLlama4ForCausalLM": _HfExamplesInfo(
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||
trust_remote_code=True,
|
||||
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
|
||||
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False,
|
||||
|
||||
@ -36,6 +36,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
||||
"KimiVLForConditionalGeneration"):
|
||||
pytest.skip("Avoid OOM")
|
||||
|
||||
if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
|
||||
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
|
||||
from vllm.model_executor.models.registry import ModelRegistry
|
||||
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
|
||||
|
||||
# Avoid OOM and reduce initialization time by only using 1 layer
|
||||
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
hf_config.update(model_info.hf_overrides)
|
||||
|
||||
@ -6,8 +6,10 @@ import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -53,14 +55,6 @@ def model_name():
|
||||
return "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def eagle_model_name():
|
||||
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def eagle3_model_name():
|
||||
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def test_ngram_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
@ -77,6 +71,8 @@ def test_ngram_correctness(
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
@ -103,34 +99,50 @@ def test_ngram_correctness(
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("model_setup", [
|
||||
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
|
||||
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
],
|
||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
use_eagle3: bool,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
method, model_name, spec_model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048)
|
||||
ref_llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_model_name = eagle3_model_name(
|
||||
) if use_eagle3 else eagle_model_name()
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
speculative_config={
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"method": method,
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
"max_model_len": 2048,
|
||||
@ -152,3 +164,5 @@ def test_eagle_correctness(
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
214
vllm/model_executor/models/llama4_eagle.py
Normal file
214
vllm/model_executor/models/llama4_eagle.py
Normal file
@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
|
||||
Llama4ForCausalLM)
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
start_layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = (
|
||||
vllm_config.speculative_config.draft_model_config.hf_config)
|
||||
self.validate_and_update_config(start_layer_id, quant_config)
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
Llama4DecoderLayer(
|
||||
self.config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||
self.config.hidden_size,
|
||||
bias=False)
|
||||
self.norm = RMSNorm(self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.fc(
|
||||
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states, hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
name = name.removeprefix("model.")
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# if PP disabled then draft will share embed with target
|
||||
if get_pp_group().world_size == 1 and \
|
||||
"embed_tokens." in name:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
for name in params_dict:
|
||||
# if PP disabled then draft will share embed with target
|
||||
if get_pp_group().world_size == 1 and \
|
||||
"embed_tokens." in name:
|
||||
continue
|
||||
assert name in loaded_params, f"{name} is not loaded!"
|
||||
return loaded_params
|
||||
|
||||
def validate_and_update_config(
|
||||
self,
|
||||
start_layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
# yoco and moe is not supported by draft model yet
|
||||
assert self.config.yoco_global_kv_layer is None
|
||||
assert self.config.yoco_local_kv_layer is None
|
||||
assert len(self.config.moe_layers) == 0
|
||||
# draft model layer index is increased by start_layer_id,
|
||||
# so we need to pad relevant configs accordingly
|
||||
self.config.no_rope_layers = [
|
||||
0
|
||||
] * start_layer_id + self.config.no_rope_layers
|
||||
# currently only TorchAO quantization is supported
|
||||
if isinstance(quant_config, TorchAOConfig):
|
||||
|
||||
def pad_layer_name(layer: str) -> str:
|
||||
layer_index = extract_layer_index(layer)
|
||||
return layer.replace(str(layer_index),
|
||||
str(layer_index + start_layer_id))
|
||||
|
||||
quant_config.torchao_config.module_fqn_to_config = {
|
||||
pad_layer_name(layer): quantization
|
||||
for layer, quantization in
|
||||
quant_config.torchao_config.module_fqn_to_config.items()
|
||||
}
|
||||
|
||||
|
||||
class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = (
|
||||
vllm_config.speculative_config.draft_model_config.hf_config)
|
||||
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||
vllm_config.parallel_config)
|
||||
# draft model quantization config may differ from target model
|
||||
quant_config = VllmConfig.get_quantization_config(
|
||||
vllm_config.speculative_config.draft_model_config,
|
||||
vllm_config.load_config)
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix="model",
|
||||
start_layer_id=target_layer_num,
|
||||
quant_config=quant_config)
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
scale=logit_scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
# lm_head is tied with target model (Llama4ForCausalLM)
|
||||
skip_prefixes=(["lm_head."]),
|
||||
)
|
||||
|
||||
model_weights = {}
|
||||
weights = [
|
||||
self.permute_qk_weight_for_rotary(name, loaded_weight)
|
||||
for name, loaded_weight in weights
|
||||
]
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head" not in name:
|
||||
name = "model." + name
|
||||
model_weights[name] = loaded_weight
|
||||
|
||||
loader.load_weights(model_weights.items())
|
||||
@ -244,6 +244,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user