[Multimodal][Speculative Decoding] Add Eagle3 multimodal support and enable it for Qwen3-VL (#29594)

Signed-off-by: Tsai, Louie <louie.tsai@intel.com>
Signed-off-by: EanWang211123 <wangyiheng@sangfor.com.cn>
Co-authored-by: Louie Tsai <louie.tsai@intel.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
EanWang211123 2025-11-28 14:05:45 +08:00 committed by GitHub
parent c7ba1f6bc7
commit 37b15e97e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 45 additions and 5 deletions

View File

@ -913,6 +913,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Qwen/Qwen2.5-VL-7B-Instruct",
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl",
),
"Eagle3Qwen3vlForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-8B-Instruct",
speculative_model="taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
),
"Qwen3NextMTP": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"
),

View File

@ -283,6 +283,19 @@ def test_speculators_model_integration(
["model_setup", "mm_enabled", "enable_chunked_prefill"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
pytest.param(
(
"eagle3",
"Qwen/Qwen3-VL-8B-Instruct",
"taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
1,
),
False,
False,
marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
),
),
pytest.param(
(
"eagle3",
@ -352,6 +365,7 @@ def test_speculators_model_integration(
],
ids=[
"qwen3_eagle3",
"qwen3_vl_eagle3",
"qwen2_5_vl_eagle3",
"llama3_eagle",
"llama3_eagle3",

View File

@ -89,6 +89,7 @@ from vllm.utils.collection_utils import is_list_of
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
@ -1122,9 +1123,14 @@ class Qwen3LLMModel(Qwen3Model):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(
positions,
hidden_states,
@ -1144,6 +1150,9 @@ class Qwen3LLMModel(Qwen3Model):
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
@ -1186,7 +1195,12 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
nn.Module,
SupportsMultiModal,
SupportsLoRA,
SupportsPP,
SupportsMRoPE,
SupportsEagle3,
):
merge_by_field_config = True
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
@ -1279,6 +1293,13 @@ class Qwen3VLForConditionalGeneration(
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Record which decoder layers should emit auxiliary hidden states.

    The indices are stored on the wrapped language model so that its
    forward pass can collect those layers' activations (consumed by the
    EAGLE3 speculative-decoding drafter).

    Args:
        layers: Decoder-layer indices to tap for auxiliary hidden states.
    """
    inner_model = self.language_model.model
    inner_model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default layer indices tapped for EAGLE3 drafting.

    Chooses one early, one middle, and one late decoder layer based on
    the language model's depth: layer 2, the midpoint, and the
    third-from-last layer.
    """
    depth = len(self.language_model.model.layers)
    early, middle, late = 2, depth // 2, depth - 3
    return (early, middle, late)
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(

View File

@ -414,6 +414,7 @@ _SPECULATIVE_DECODING_MODELS = {
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

View File

@ -1017,10 +1017,10 @@ class EagleProposer:
if supports_multimodal(target_model):
# handle multimodality
if (
self.get_model_name(target_model)
== "Qwen2_5_VLForConditionalGeneration"
):
if self.get_model_name(target_model) in [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
else:
self.model.config.image_token_index = (