[Spec decode] automatically disable mm for text-only draft models (#25667)

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in:
Jonas M. Kübler 2025-09-27 02:10:21 +02:00 committed by GitHub
parent 4e33a7ea85
commit 6f5c0931c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 62 deletions

View File

@ -8,7 +8,7 @@ from typing import Any, Union
import pytest import pytest
import torch import torch
from tests.utils import get_attn_backend_list_based_on_platform from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR from vllm.assets.image import VLM_IMAGES_DIR
@ -88,8 +88,6 @@ def test_ngram_correctness(
Compare the outputs of an original LLM and a speculative LLM Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding. should be the same when using ngram speculative decoding.
''' '''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False) test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024) ref_llm = LLM(model=model_name, max_model_len=1024)
@ -127,29 +125,28 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ @pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct", (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct", (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param( pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False, False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), marks=large_gpu_mark(min_gb=80)), # works on 4x H100
pytest.param( pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True, True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), marks=large_gpu_mark(min_gb=80)), # works on 4x H100
(("eagle", "eagle618/deepseek-v3-random", (("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False), "eagle618/eagle-deepseek-v3-random", 1), False),
], ],
ids=[ ids=[
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
"llama4_eagle", "llama4_eagle_mm", "llama4_eagle_mm", "deepseek_eagle"
"deepseek_eagle"
]) ])
@pytest.mark.parametrize("attn_backend", @pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform()) get_attn_backend_list_based_on_platform())
@ -174,7 +171,12 @@ def test_eagle_correctness(
model_setup: (method, model_name, eagle_model_name, tp_size) model_setup: (method, model_name, eagle_model_name, tp_size)
''' '''
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
else:
m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

View File

@ -804,6 +804,20 @@ class EagleProposer:
self.attn_layer_names = list(draft_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names)
if self.is_multimodal_model:
# Even if the target model is multimodal, we can also use
# text-only draft models
try:
dummy_input_ids = torch.tensor([[1]],
device=self.input_ids.device)
self.model.get_input_embeddings(dummy_input_ids,
multimodal_embeddings=None)
except (NotImplementedError, AttributeError, TypeError):
logger.warning(
"Draft model does not support multimodal inputs, "
"falling back to text-only mode")
self.is_multimodal_model = False
if supports_multimodal(target_model): if supports_multimodal(target_model):
# handle multimodality # handle multimodality
self.model.config.image_token_index = ( self.model.config.image_token_index = (