[Multimodal][Speculative Decoding]Eagle Eagle3 mm support, enablement on qwen2.5vl (#22872)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
WeiQing Chen 2025-09-27 11:35:47 +08:00 committed by yewentao256
parent 55971f85c9
commit 38c2df831a
8 changed files with 210 additions and 45 deletions

View File

@ -651,6 +651,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
"Qwen/Qwen2.5-VL-7B-Instruct",
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.3"),
}

View File

@ -129,6 +129,11 @@ def test_ngram_correctness(
["model_setup", "mm_enabled"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
False,
marks=pytest.mark.skip(reason="Skipping due to its " \
"head_dim not being a a multiple of 32")),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@ -145,8 +150,8 @@ def test_ngram_correctness(
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
"llama4_eagle_mm", "deepseek_eagle"
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())

View File

@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
):
dataset_class = MLPerfDataset
args.hf_split = "train"
elif (
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = MMStarDataset
args.hf_split = "val"
args.hf_subset = None
else:
supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__()
@ -2721,3 +2728,76 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
random.shuffle(requests)
return requests
# -----------------------------------------------------------------------------
# MMStar Dataset Implementation
# -----------------------------------------------------------------------------
class MMStarDataset(HuggingFaceDataset):
"""
Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
refer to: https://github.com/sgl-project/SpecForge/pull/106
"""
DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
IS_MULTIMODAL = True
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
no_oversample: bool = False,
**kwargs,
) -> list[SampleRequest]:
# If --hf-output-len is not set, use the default output length.
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests: list[SampleRequest] = []
for ind, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
# Split the question text from options
# (keep only the part before "Options:").
full_q: str = item.get("question", "")
question_text = full_q.split("Options:", 1)[0].strip()
# Multimodal image content.
mm_content = process_image(item["image"])
# Compute prompt token length (note: this is plain text length
# if enable_multimodal_chat is False).
prompt_len = len(tokenizer(question_text).input_ids)
if enable_multimodal_chat:
# If multimodal content should be embedded in the chat message,
# convert to [{"role":"user","content":[...]}]
prompt = self.apply_multimodal_chat_transformation(
question_text, mm_content
)
mm_for_request = None # Already embedded in chat content.
else:
# Default: prompt is plain text,
# image is in mm_content for the bench to assemble.
prompt = question_text
mm_for_request = mm_content
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_for_request,
request_id=request_id_prefix + str(ind),
)
)
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix, no_oversample
)
return sampled_requests

View File

@ -8,7 +8,6 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
@ -102,7 +102,6 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
return hidden_states, residual
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
@ -145,13 +144,21 @@ class LlamaModel(nn.Module):
eps=self.config.rms_norm_eps,
)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if input_embeds is None:
input_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == input_embeds.shape[-1]
residual = None
@ -239,11 +246,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)
return self.model(input_ids, positions, hidden_states, inputs_embeds)
def compute_logits(
self,
@ -299,3 +302,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
return inputs_embeds

View File

@ -68,7 +68,7 @@ from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA,
SupportsMultiModal, SupportsMultiModalPruning,
SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
@ -965,7 +965,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP,
SupportsQuant,
SupportsQuant, SupportsEagle3,
SupportsMultiModalPruning):
packed_modules_mapping = {
@ -1028,6 +1028,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):

View File

@ -286,6 +286,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

View File

@ -80,9 +80,17 @@ class EagleProposer:
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
self.uses_mrope = self.vllm_config.model_config.uses_mrope
if self.uses_mrope:
# M-RoPE need (3, max_num_tokens)
self.mrope_positions = torch.zeros((3, self.max_num_tokens),
dtype=torch.int64,
device=device)
else:
# RoPE need (max_num_tokens,)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
@ -143,11 +151,22 @@ class EagleProposer:
dtype=torch.int32,
).repeat(max_batch_size, 1)
def _get_positions(self, num_tokens: int):
if self.uses_mrope:
return self.mrope_positions[:, :num_tokens]
return self.positions[:num_tokens]
def _set_positions(self, num_tokens: int, positions: torch.Tensor):
if self.uses_mrope:
self.mrope_positions[:, :num_tokens] = positions
else:
self.positions[:num_tokens] = positions
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
# [num_tokens] or [3, num_tokens] when M-RoPE is enabled
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
@ -198,7 +217,7 @@ class EagleProposer:
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
@ -218,7 +237,7 @@ class EagleProposer:
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
positions=self._get_positions(num_input_tokens),
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
@ -235,7 +254,10 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1)
return draft_token_ids.view(-1, 1)
positions = target_positions[last_token_indices]
if self.uses_mrope:
positions = target_positions[:, last_token_indices]
else:
positions = target_positions[last_token_indices]
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
hidden_states = self.hidden_states[last_token_indices]
else:
@ -282,25 +304,34 @@ class EagleProposer:
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
if self.uses_mrope:
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length.
# Since it is complex to remove such requests from the batch,
# we keep them in the batch but adjust the position ids
# and slot mappings to avoid the
# out-of-range access during the model execution.
# The draft tokens generated with this adjustment
# should be ignored.
exceeds_max_model_len = positions[0] >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where\
(exceeds_max_model_len.unsqueeze(0), \
torch.zeros_like(positions), positions)
else:
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
1)
@ -308,13 +339,22 @@ class EagleProposer:
common_attn_metadata.seq_lens_cpu - 1
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // self.block_size
else:
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
if self.uses_mrope:
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions[0] % self.block_size)
else:
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
@ -330,7 +370,7 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
@ -347,7 +387,7 @@ class EagleProposer:
num_tokens=input_batch_size):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
positions=self._get_positions(input_batch_size),
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
@ -787,6 +827,11 @@ class EagleProposer:
return spec_common_attn_metadata, token_indices
def get_model_name(self, model: nn.Module) -> str:
if hasattr(model, 'module'): # multi-GPU
model = model.module
return model.__class__.__name__
def load_model(self, target_model: nn.Module) -> None:
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
@ -820,8 +865,13 @@ class EagleProposer:
if supports_multimodal(target_model):
# handle multimodality
self.model.config.image_token_index = (
target_model.config.image_token_index)
if (self.get_model_name(target_model) ==
"Qwen2_5_VLForConditionalGeneration"):
self.model.config.image_token_index = (
target_model.config.image_token_id)
else:
self.model.config.image_token_index = (
target_model.config.image_token_index)
target_language_model = target_model.get_language_model()
else:
target_language_model = target_model
@ -892,7 +942,7 @@ class EagleProposer:
self.model(
input_ids=input_ids,
positions=self.positions[:num_tokens],
positions=self._get_positions(num_tokens),
hidden_states=self.hidden_states[:num_tokens],
inputs_embeds=inputs_embeds,
)

View File

@ -442,6 +442,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(self,
*size: Union[int, torch.SymInt],
dtype: torch.dtype,
@ -2544,8 +2554,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions.gpu[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat(
@ -2570,8 +2579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
valid_sampled_tokens_count)
target_token_ids = self.input_ids.gpu[token_indices]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions.gpu[token_indices]
target_positions = self._get_positions(token_indices)
if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat(