[Meta] Official Eagle mm support, first enablement on llama4 (#20788)

Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
Authored by zhiweiz on 2025-07-31 10:35:07 -07:00; committed by GitHub
parent 53c21e492e
commit 9e0726e5bf
8 changed files with 205 additions and 36 deletions

View File

@ -13,6 +13,38 @@ except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
QUESTION = "What is the content of each image?"
IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]
def get_custom_mm_prompts(num_prompts):
prompts = []
for url in IMAGE_URLS:
prompts.append(
[
{"type": "image_url", "image_url": {"url": url}},
{"type": "text", "text": QUESTION},
]
)
if num_prompts > len(IMAGE_URLS):
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
@ -35,6 +67,7 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
return parser.parse_args()
@ -44,14 +77,26 @@ def main():
model_dir = args.model_dir
if args.model_dir is None:
if args.custom_mm_prompts:
raise ValueError(
    "custom_mm_prompts requires an mm-based model; "
    "the default llama3.1-8b-instruct is not mm-based. "
    "Please pass --model-dir pointing to an mm-based model."
)
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
args.custom_skip_chat_template = True
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
]
if not args.custom_mm_prompts:
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice
# when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False)
for prompt in prompts
]
else:
prompts = get_custom_mm_prompts(args.num_prompts)
if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
@ -85,10 +130,17 @@ def main():
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
if not args.custom_mm_prompts:
outputs = llm.generate(
prompt_token_ids=prompt_ids, sampling_params=sampling_params
)
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)
# print the generated text
if args.print_output:
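Taken together, these changes let the benchmark drive an EAGLE draft model with image prompts end to end through llm.chat. A hypothetical invocation (the script name here is assumed, and --method / --num-prompts are inferred from args.method / args.num_prompts rather than shown in this diff) would look roughly like:
python spec_decode.py --method eagle --model-dir meta-llama/Llama-4-Scout-17B-16E-Instruct --eagle-dir morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct --custom-mm-prompts --num-prompts 12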

View File

@ -3,29 +3,34 @@
from __future__ import annotations
import random
from typing import Any
from typing import Any, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
@pytest.fixture
def test_prompts():
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: Union[str, list[dict[str, Any]]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@ -38,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@ -57,7 +77,6 @@ def model_name():
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
@ -67,6 +86,7 @@ def test_ngram_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@ -103,23 +123,32 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("model_setup", [
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
@pytest.mark.parametrize(
["model_setup", "mm_enabled"], [
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using a fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of an original LLM and a speculative LLM;
they should be the same when using EAGLE speculative decoding.
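The remainder of the test is elided in this hunk. In outline, it builds a speculative LLM from model_setup and compares its chat outputs against the reference model's, roughly as sketched below (a sketch based on vLLM's speculative_config dict API; the tensor-parallel size comes from model_setup, the other values are illustrative and not the verbatim test code):

method, model_name, spec_model_name, tp_size = model_setup
ref_llm = LLM(model=model_name, tensor_parallel_size=tp_size, max_model_len=2048)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
cleanup_dist_env_and_memory()

spec_llm = LLM(
    model=model_name,
    speculative_config={
        "method": method,
        "model": spec_model_name,
        "num_speculative_tokens": 3,  # illustrative value
    },
    tensor_parallel_size=tp_size,
    max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
# The two output sets are then compared for (near-)agreement.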

View File

@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling

View File

@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
logger = init_logger(__name__)
@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
torch.cat((inputs_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
hidden_states, residual = layer(
@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
return self.model(input_ids, positions, hidden_states, inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
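For context, merge_multimodal_embeddings (imported from .utils above) writes the image embeddings into the token-embedding sequence at the positions occupied by the image placeholder token (self.config.image_token_index). A simplified sketch of the semantics, assuming a single flat embedding tensor rather than the NestedTensors the real helper accepts:

import torch

def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 mm_embeds: torch.Tensor, image_token_index: int) -> torch.Tensor:
    # Rows whose token id is the image placeholder receive the corresponding
    # multimodal embedding, in order; text rows keep their token embeddings.
    out = inputs_embeds.clone()
    out[input_ids == image_token_index] = mm_embeds.to(out.dtype)
    return out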

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)
def compute_logits(

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
@ -51,6 +53,9 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
@ -76,6 +81,11 @@ class EagleProposer:
device=device,
dtype=torch.int32)
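# Persistent buffer for multimodal input embeddings, mirroring the other
# preallocated per-step buffers (input_ids, positions, hidden_states): each
# call fills only the first num_tokens rows, so the draft model always sees
# fixed-shape tensors (e.g. for CUDA-graph capture when it is enabled).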
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
def propose(
self,
# [num_tokens]
@ -88,6 +98,7 @@ class EagleProposer:
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@ -128,14 +139,27 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
@ -218,15 +242,24 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
@ -391,10 +424,18 @@ class EagleProposer:
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
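# The warm-up/dummy run mirrors the real execution path: a multimodal draft
# model is fed inputs_embeds (with input_ids=None), so compilation and
# CUDA-graph capture trace the same branch and tensor shapes that propose()
# uses at runtime.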
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
input_ids=input_ids,
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
inputs_embeds=inputs_embeds,
)
def validate_same_kv_cache_group(self,

View File

@ -1205,13 +1205,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
@ -1858,6 +1860,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.is_multimodal_model:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
@ -1865,6 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
next_token_ids=next_token_ids,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
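The body of _gather_mm_embeddings is mostly elided above. Conceptually, it collects, for each request, the slice of every cached image embedding whose placeholder positions overlap the window of tokens processed this step; shift_computed_tokens=1 presumably shifts that window by one position for the drafter, which consumes tokens one step ahead of the target pass in EAGLE. A rough sketch of the overlap selection (offset comes from the visible pos_info usage; the length parameter and encoder_output name are assumptions for illustration):

from typing import Optional

import torch


def overlap_slice(encoder_output: torch.Tensor, offset: int, length: int,
                  num_computed_tokens: int,
                  num_scheduled_tokens: int) -> Optional[torch.Tensor]:
    """Rows of a cached image embedding that fall inside this step's window
    [num_computed_tokens, num_computed_tokens + num_scheduled_tokens), or
    None if the placeholder does not overlap the window at all."""
    start = max(num_computed_tokens - offset, 0)
    end = min(num_computed_tokens + num_scheduled_tokens - offset, length)
    if start >= end:
        return None
    return encoder_output[start:end]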