mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:55:01 +08:00
[Meta] Llama4 EAGLE Support (#20591)
Signed-off-by: qizixi <qizixi@meta.com> Co-authored-by: qizixi <qizixi@meta.com>
This commit is contained in:
parent
1eb2b9c102
commit
c11013db8b
@ -84,6 +84,7 @@ def main():
|
|||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.8,
|
||||||
speculative_config=speculative_config,
|
speculative_config=speculative_config,
|
||||||
disable_log_stats=False,
|
disable_log_stats=False,
|
||||||
|
max_model_len=16384,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||||
|
|||||||
@ -465,6 +465,11 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||||
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
|
||||||
|
"EagleLlama4ForCausalLM": _HfExamplesInfo(
|
||||||
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||||
|
trust_remote_code=True,
|
||||||
|
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||||
|
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
|
||||||
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
|
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
is_available_online=False,
|
is_available_online=False,
|
||||||
|
|||||||
@ -36,6 +36,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
|||||||
"KimiVLForConditionalGeneration"):
|
"KimiVLForConditionalGeneration"):
|
||||||
pytest.skip("Avoid OOM")
|
pytest.skip("Avoid OOM")
|
||||||
|
|
||||||
|
if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
|
||||||
|
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
|
||||||
|
from vllm.model_executor.models.registry import ModelRegistry
|
||||||
|
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
|
||||||
|
|
||||||
# Avoid OOM and reduce initialization time by only using 1 layer
|
# Avoid OOM and reduce initialization time by only using 1 layer
|
||||||
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
|
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||||
hf_config.update(model_info.hf_overrides)
|
hf_config.update(model_info.hf_overrides)
|
||||||
|
|||||||
@ -6,8 +6,10 @@ import random
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -53,14 +55,6 @@ def model_name():
|
|||||||
return "meta-llama/Llama-3.1-8B-Instruct"
|
return "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
def eagle_model_name():
|
|
||||||
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
|
||||||
|
|
||||||
|
|
||||||
def eagle3_model_name():
|
|
||||||
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_correctness(
|
def test_ngram_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
test_prompts: list[list[dict[str, Any]]],
|
||||||
@ -77,6 +71,8 @@ def test_ngram_correctness(
|
|||||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
del ref_llm
|
del ref_llm
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
spec_llm = LLM(
|
spec_llm = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@ -103,34 +99,50 @@ def test_ngram_correctness(
|
|||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches > int(0.7 * len(ref_outputs))
|
assert matches > int(0.7 * len(ref_outputs))
|
||||||
del spec_llm
|
del spec_llm
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
@pytest.mark.parametrize("model_setup", [
|
||||||
|
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
|
||||||
|
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
|
||||||
|
pytest.param(
|
||||||
|
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
|
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||||
|
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||||
|
],
|
||||||
|
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
|
||||||
def test_eagle_correctness(
|
def test_eagle_correctness(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_prompts: list[list[dict[str, Any]]],
|
test_prompts: list[list[dict[str, Any]]],
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
model_name: str,
|
model_setup: tuple[str, str, str, int],
|
||||||
use_eagle3: bool,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
Compare the outputs of a original LLM and a speculative LLM
|
Compare the outputs of a original LLM and a speculative LLM
|
||||||
should be the same when using eagle speculative decoding.
|
should be the same when using eagle speculative decoding.
|
||||||
|
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||||
'''
|
'''
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
method, model_name, spec_model_name, tp_size = model_setup
|
||||||
|
|
||||||
ref_llm = LLM(model=model_name, max_model_len=2048)
|
ref_llm = LLM(model=model_name,
|
||||||
|
max_model_len=2048,
|
||||||
|
tensor_parallel_size=tp_size)
|
||||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||||
del ref_llm
|
del ref_llm
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
spec_model_name = eagle3_model_name(
|
|
||||||
) if use_eagle3 else eagle_model_name()
|
|
||||||
spec_llm = LLM(
|
spec_llm = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
tensor_parallel_size=tp_size,
|
||||||
speculative_config={
|
speculative_config={
|
||||||
"method": "eagle3" if use_eagle3 else "eagle",
|
"method": method,
|
||||||
"model": spec_model_name,
|
"model": spec_model_name,
|
||||||
"num_speculative_tokens": 3,
|
"num_speculative_tokens": 3,
|
||||||
"max_model_len": 2048,
|
"max_model_len": 2048,
|
||||||
@ -152,3 +164,5 @@ def test_eagle_correctness(
|
|||||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||||
assert matches > int(0.66 * len(ref_outputs))
|
assert matches > int(0.66 * len(ref_outputs))
|
||||||
del spec_llm
|
del spec_llm
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|||||||
214
vllm/model_executor/models/llama4_eagle.py
Normal file
214
vllm/model_executor/models/llama4_eagle.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
|
||||||
|
Llama4ForCausalLM)
|
||||||
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
|
|
||||||
|
from .utils import AutoWeightsLoader, maybe_prefix
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
|
class LlamaModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
start_layer_id: int = 0,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = (
|
||||||
|
vllm_config.speculative_config.draft_model_config.hf_config)
|
||||||
|
self.validate_and_update_config(start_layer_id, quant_config)
|
||||||
|
self.vocab_size = self.config.vocab_size
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.config.vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
Llama4DecoderLayer(
|
||||||
|
self.config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||||
|
) for i in range(self.config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||||
|
self.config.hidden_size,
|
||||||
|
bias=False)
|
||||||
|
self.norm = RMSNorm(self.config.hidden_size,
|
||||||
|
eps=self.config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
input_embeds = self.embed_tokens(input_ids)
|
||||||
|
hidden_states = self.fc(
|
||||||
|
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||||
|
residual = None
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states, hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: set[str] = set()
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
name = name.removeprefix("model.")
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# if PP disabled then draft will share embed with target
|
||||||
|
if get_pp_group().world_size == 1 and \
|
||||||
|
"embed_tokens." in name:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
for name in params_dict:
|
||||||
|
# if PP disabled then draft will share embed with target
|
||||||
|
if get_pp_group().world_size == 1 and \
|
||||||
|
"embed_tokens." in name:
|
||||||
|
continue
|
||||||
|
assert name in loaded_params, f"{name} is not loaded!"
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
def validate_and_update_config(
|
||||||
|
self,
|
||||||
|
start_layer_id: int,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
|
# yoco and moe is not supported by draft model yet
|
||||||
|
assert self.config.yoco_global_kv_layer is None
|
||||||
|
assert self.config.yoco_local_kv_layer is None
|
||||||
|
assert len(self.config.moe_layers) == 0
|
||||||
|
# draft model layer index is increased by start_layer_id,
|
||||||
|
# so we need to pad relevant configs accordingly
|
||||||
|
self.config.no_rope_layers = [
|
||||||
|
0
|
||||||
|
] * start_layer_id + self.config.no_rope_layers
|
||||||
|
# currently only TorchAO quantization is supported
|
||||||
|
if isinstance(quant_config, TorchAOConfig):
|
||||||
|
|
||||||
|
def pad_layer_name(layer: str) -> str:
|
||||||
|
layer_index = extract_layer_index(layer)
|
||||||
|
return layer.replace(str(layer_index),
|
||||||
|
str(layer_index + start_layer_id))
|
||||||
|
|
||||||
|
quant_config.torchao_config.module_fqn_to_config = {
|
||||||
|
pad_layer_name(layer): quantization
|
||||||
|
for layer, quantization in
|
||||||
|
quant_config.torchao_config.module_fqn_to_config.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.config = (
|
||||||
|
vllm_config.speculative_config.draft_model_config.hf_config)
|
||||||
|
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||||
|
vllm_config.parallel_config)
|
||||||
|
# draft model quantization config may differ from target model
|
||||||
|
quant_config = VllmConfig.get_quantization_config(
|
||||||
|
vllm_config.speculative_config.draft_model_config,
|
||||||
|
vllm_config.load_config)
|
||||||
|
self.model = LlamaModel(vllm_config=vllm_config,
|
||||||
|
prefix="model",
|
||||||
|
start_layer_id=target_layer_num,
|
||||||
|
quant_config=quant_config)
|
||||||
|
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||||
|
scale=logit_scale)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return self.model(input_ids, positions, hidden_states)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> None:
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
# lm_head is tied with target model (Llama4ForCausalLM)
|
||||||
|
skip_prefixes=(["lm_head."]),
|
||||||
|
)
|
||||||
|
|
||||||
|
model_weights = {}
|
||||||
|
weights = [
|
||||||
|
self.permute_qk_weight_for_rotary(name, loaded_weight)
|
||||||
|
for name, loaded_weight in weights
|
||||||
|
]
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "lm_head" not in name:
|
||||||
|
name = "model." + name
|
||||||
|
model_weights[name] = loaded_weight
|
||||||
|
|
||||||
|
loader.load_weights(model_weights.items())
|
||||||
@ -244,6 +244,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
|||||||
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
||||||
"EAGLEModel": ("eagle", "EAGLE"),
|
"EAGLEModel": ("eagle", "EAGLE"),
|
||||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||||
|
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user