[Meta] Official Eagle mm support, first enablement on llama4 (#20788)

Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
Author: zhiweiz
Date: 2025-07-31 10:35:07 -07:00 (committed by GitHub)
Parent: 53c21e492e
Commit: 9e0726e5bf
8 changed files with 205 additions and 36 deletions
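
For orientation, here is a minimal sketch of the path this change enables: multimodal chat prompts served with EAGLE speculative decoding on Llama 4. The model names, image URL, per-prompt image limit, and disable_chunked_mm_input come from the diff below; tensor_parallel_size and num_speculative_tokens are illustrative assumptions, not values fixed by this commit.

from vllm import LLM, SamplingParams

# Target/draft pair taken from the llama4_eagle_mm test case added below.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,  # illustrative; the test matrix uses tp=4
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,  # illustrative
    },
    max_model_len=16384,
    limit_mm_per_prompt={"image": 5},
    disable_chunked_mm_input=True,
)

messages = [[{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg"}},
        {"type": "text", "text": "What is the content of each image?"},
    ],
}]]
outputs = llm.chat(messages, SamplingParams(temperature=0, max_tokens=256))
print(outputs[0].outputs[0].text)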

View File

@@ -13,6 +13,38 @@ except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
+QUESTION = "What is the content of each image?"
+IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
+]
+
+
+def get_custom_mm_prompts(num_prompts):
+    prompts = []
+    for url in IMAGE_URLS:
+        prompts.append(
+            [
+                {"type": "image_url", "image_url": {"url": url}},
+                {"type": "text", "text": QUESTION},
+            ]
+        )
+    if num_prompts > len(IMAGE_URLS):
+        prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
+    return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
+
+
 def parse_args():
     parser = FlexibleArgumentParser()
     add_dataset_parser(parser)

@@ -35,6 +67,7 @@ def parse_args():
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--custom-mm-prompts", action="store_true")
     return parser.parse_args()

@@ -44,14 +77,26 @@ def main():
     model_dir = args.model_dir
     if args.model_dir is None:
+        if args.custom_mm_prompts:
+            raise ValueError(
+                "custom_mm_prompts requires mm based models"
+                "default llama3.1-8b-instruct is not mm based"
+                "please specify model_dir to give a mm based model"
+            )
         model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    args.custom_skip_chat_template = True
 
-    prompts = get_samples(args, tokenizer)
-    # add_special_tokens is False to avoid adding bos twice when using chat templates
-    prompt_ids = [
-        tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
-    ]
+    if not args.custom_mm_prompts:
+        prompts = get_samples(args, tokenizer)
+        # add_special_tokens is False to avoid adding bos twice
+        # when using chat templates
+        prompt_ids = [
+            tokenizer.encode(prompt.prompt, add_special_tokens=False)
+            for prompt in prompts
+        ]
+    else:
+        prompts = get_custom_mm_prompts(args.num_prompts)
 
     if args.method == "eagle" or args.method == "eagle3":
         eagle_dir = args.eagle_dir

@@ -85,10 +130,17 @@ def main():
         speculative_config=speculative_config,
         disable_log_stats=False,
         max_model_len=16384,
+        limit_mm_per_prompt={"image": 5},
+        disable_chunked_mm_input=True,
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
 
-    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
+    if not args.custom_mm_prompts:
+        outputs = llm.generate(
+            prompt_token_ids=prompt_ids, sampling_params=sampling_params
+        )
+    else:
+        outputs = llm.chat(prompts, sampling_params=sampling_params)
 
     # print the generated text
     if args.print_output:

View File

@@ -3,29 +3,34 @@
 from __future__ import annotations
 
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
+    if mm_enabled:
+        prompt_types.append("mm")
     num_prompts = 100
     prompts = []
 
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+    print(f"Prompt types: {random_prompt_type_choices}")
 
     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: Union[str, list[dict[str, Any]]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.

@@ -38,6 +43,21 @@ def get_test_prompts(mm_enabled: bool):
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
             """
+        elif kind == "mm":
+            placeholders = [{
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }]
+            prompt = [
+                *placeholders,
+                {
+                    "type": "text",
+                    "text": "The meaning of the image is"
+                },
+            ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
         prompts.append([{"role": "user", "content": prompt}])

@@ -57,7 +77,6 @@ def model_name():
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):

@@ -67,6 +86,7 @@ def test_ngram_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        test_prompts = get_test_prompts(mm_enabled=False)
 
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)

@@ -103,23 +123,32 @@ def test_ngram_correctness(
         cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize("model_setup", [
-    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
-    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-],
-                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"], [
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            False,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            True,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    ],
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
 ):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.

View File

@@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
         super().__init__()
 
         self.layer_idx = extract_layer_index(prefix)
+        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
         self.hidden_size = config.hidden_size
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling

View File

@@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                                Llama4ForCausalLM)
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)

@@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
         self.norm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = self.fc(
-            torch.cat((input_embeds, hidden_states), dim=-1))
+            torch.cat((inputs_embeds, hidden_states), dim=-1))
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(

@@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.model(input_ids, positions, hidden_states)
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:

@@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
             model_weights[name] = loaded_weight
 
         loader.load_weights(model_weights.items())
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.config.image_token_index,
+            )
+
+        return inputs_embeds
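
The new get_input_embeddings override follows the usual vLLM multimodal pattern: image embeddings are scattered into the text-embedding sequence at the positions occupied by the image placeholder token before the draft layers run. Below is a rough standalone sketch of that merge behavior, not the actual merge_multimodal_embeddings utility (which also validates placeholder counts and handles nested tensor structures).

import torch

def merge_mm_embeddings_sketch(
    input_ids: torch.Tensor,      # [num_tokens]
    inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size]
    mm_embeds: torch.Tensor,      # [num_placeholder_tokens, hidden_size]
    placeholder_token_id: int,
) -> torch.Tensor:
    # Overwrite the rows whose token id is the image placeholder with the
    # image embeddings, in order of appearance.
    out = inputs_embeds.clone()
    out[input_ids == placeholder_token_id] = mm_embeds.to(out.dtype)
    return out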

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn

@@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def compute_logits(

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
 import numpy as np
 import torch
 import torch.nn as nn

@@ -51,6 +53,9 @@ class EagleProposer:
         # hidden size (e.g., Llama 3.3 70B).
         self.hidden_size = self.draft_model_config.get_hidden_size()
 
+        self.is_multimodal_model = vllm_config.model_config \
+            .is_multimodal_model
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)

@@ -76,6 +81,11 @@ class EagleProposer:
                                      device=device,
                                      dtype=torch.int32)
 
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
+
     def propose(
         self,
         # [num_tokens]

@@ -88,6 +98,7 @@ class EagleProposer:
         next_token_ids: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
         sampling_metadata: SamplingMetadata,
+        mm_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]

@@ -128,14 +139,27 @@ class EagleProposer:
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
+        if self.is_multimodal_model:
+            input_ids = self.input_ids[:num_tokens]
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids,
+                multimodal_embeddings=mm_embeds or None,
+            )
+            self.inputs_embeds[:num_tokens] = inputs_embeds
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            input_ids = None
+        else:
+            inputs_embeds = None
+            input_ids = self.input_ids[:num_input_tokens]
 
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(
-                self.input_ids[:num_input_tokens],
-                self.positions[:num_input_tokens],
-                self.hidden_states[:num_input_tokens],
+                input_ids=input_ids,
+                positions=self.positions[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
+                inputs_embeds=inputs_embeds,
             )
             if self.method == "deepseek_mtp":
                 last_hidden_states = ret_hidden_states

@@ -218,15 +242,24 @@ class EagleProposer:
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
+            if self.is_multimodal_model:
+                inputs_embeds = self.model.get_input_embeddings(input_ids)
+                self.inputs_embeds[:batch_size] = inputs_embeds
+                inputs_embeds = self.inputs_embeds[:input_batch_size]
+                input_ids = None
+            else:
+                inputs_embeds = None
+                input_ids = self.input_ids[:input_batch_size]
 
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    self.input_ids[:input_batch_size],
-                    self.positions[:input_batch_size],
-                    self.hidden_states[:input_batch_size],
+                    input_ids=input_ids,
+                    positions=self.positions[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
+                    inputs_embeds=inputs_embeds,
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],

@@ -391,10 +424,18 @@ class EagleProposer:
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
             self.model(
-                self.input_ids[:num_tokens],
-                self.positions[:num_tokens],
-                self.hidden_states[:num_tokens],
+                input_ids=input_ids,
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+                inputs_embeds=inputs_embeds,
             )
 
     def validate_same_kv_cache_group(self,

View File

@@ -1205,13 +1205,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _gather_mm_embeddings(
         self,
         scheduler_output: "SchedulerOutput",
+        shift_computed_tokens: int = 0,
     ) -> list[torch.Tensor]:
         mm_embeds: list[torch.Tensor] = []
         for req_id in self.input_batch.req_ids:
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
-            num_computed_tokens = req_state.num_computed_tokens
+            num_computed_tokens = \
+                req_state.num_computed_tokens + shift_computed_tokens
             mm_positions = req_state.mm_positions
             for i, pos_info in enumerate(mm_positions):
                 start_pos = pos_info.offset

@@ -1858,6 +1860,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_hidden_states = hidden_states[token_indices]
+
+            mm_embeds = None
+            if self.is_multimodal_model:
+                mm_embeds = self._gather_mm_embeddings(scheduler_output,
+                                                       shift_computed_tokens=1)
             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,

@@ -1865,6 +1872,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 next_token_ids=next_token_ids,
                 sampling_metadata=sampling_metadata,
                 common_attn_metadata=common_attn_metadata,
+                mm_embeds=mm_embeds,
             )
             spec_token_ids = draft_token_ids.tolist()
             return spec_token_ids
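
A small worked illustration of why _gather_mm_embeddings is called with shift_computed_tokens=1 for drafting, assuming the usual EAGLE setup in which the draft pass consumes the target token ids shifted left by one position (with the freshly sampled token appended): the window of positions whose image embeddings the drafter may touch starts one token later than the target pass. The numbers below are hypothetical, for illustration only.

# Suppose a request has 10 already-computed tokens, 4 newly scheduled tokens,
# and an image whose placeholder tokens occupy positions 12..19.
num_computed_tokens = 10
num_scheduled_tokens = 4
image_positions = range(12, 20)

# Target pass covers positions [10, 14); the draft pass, shifted by one,
# covers positions [11, 15) and therefore also needs position 14's embedding.
target_window = range(num_computed_tokens,
                      num_computed_tokens + num_scheduled_tokens)
draft_window = range(num_computed_tokens + 1,
                     num_computed_tokens + 1 + num_scheduled_tokens)

assert [p for p in target_window if p in image_positions] == [12, 13]
assert [p for p in draft_window if p in image_positions] == [12, 13, 14]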