[Spec decode] automatically disable mm for text-only draft models (#25667)

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Jonas M. Kübler authored on 2025-09-27 02:10:21 +02:00; committed by GitHub
parent 4e33a7ea85
commit 6f5c0931c1
2 changed files with 78 additions and 62 deletions
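
When a multimodal target model is paired with a text-only EAGLE draft model, the EagleProposer now probes the draft model's get_input_embeddings() with a dummy token ID and multimodal_embeddings=None; if the call fails, it logs a warning and resets is_multimodal_model so speculation proceeds in text-only mode. On the test side, the Llama-4 Scout EAGLE cases are re-enabled behind large_gpu_mark(min_gb=80) instead of an unconditional CI skip, the redundant VLLM_USE_V1 setup is removed, and Scout with FLASH_ATTN keeps the default backend selection because its vision encoder's head_dim of 88 is incompatible with FlashAttention.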

tests/v1/e2e/test_spec_decode.py

@@ -8,7 +8,7 @@ from typing import Any, Union
 import pytest
 import torch
-from tests.utils import get_attn_backend_list_based_on_platform
+from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
@@ -88,69 +88,66 @@ def test_ngram_correctness(
     Compare the outputs of an original LLM and a speculative LLM
     should be the same when using ngram speculative decoding.
     '''
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        test_prompts = get_test_prompts(mm_enabled=False)
+    test_prompts = get_test_prompts(mm_enabled=False)

-        ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
+    ref_llm = LLM(model=model_name, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()

-        spec_llm = LLM(
-            model=model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")

-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches >= int(0.66 * len(ref_outputs))
-        del spec_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches >= int(0.66 * len(ref_outputs))
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()

-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        False,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        True,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    (("eagle", "eagle618/deepseek-v3-random",
-      "eagle618/eagle-deepseek-v3-random", 1), False),
-],
-                         ids=[
-                             "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
-                             "llama4_eagle", "llama4_eagle_mm",
-                             "deepseek_eagle"
-                         ])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     False,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     True,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
+    ],
+    ids=[
+        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
+        "llama4_eagle_mm", "deepseek_eagle"
+    ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
@@ -174,9 +171,14 @@ def test_eagle_correctness(
     model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_MLA_DISABLE", "1")
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+            # Scout requires default backend selection because its vision
+            # encoder has head_dim 88, which is incompatible with FLASH_ATTN
+            # and needs to fall back to FlexAttention.
+            pass
+        else:
+            m.setenv("VLLM_MLA_DISABLE", "1")
+            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

         if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
             pytest.skip("TRITON_ATTN does not support "

vllm/v1/spec_decode/eagle.py

@@ -804,6 +804,20 @@ class EagleProposer:
         self.attn_layer_names = list(draft_attn_layer_names)

+        if self.is_multimodal_model:
+            # Even if the target model is multimodal, we can still use a
+            # text-only draft model.
+            try:
+                dummy_input_ids = torch.tensor([[1]],
+                                               device=self.input_ids.device)
+                self.model.get_input_embeddings(dummy_input_ids,
+                                                multimodal_embeddings=None)
+            except (NotImplementedError, AttributeError, TypeError):
+                logger.warning(
+                    "Draft model does not support multimodal inputs, "
+                    "falling back to text-only mode")
+                self.is_multimodal_model = False
+
         if supports_multimodal(target_model):
             # handle multimodality
             self.model.config.image_token_index = (
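
The probe relies on duck typing: a text-only draft model's get_input_embeddings() typically accepts only input_ids, so passing the multimodal_embeddings keyword raises TypeError, while a missing or stubbed method raises AttributeError or NotImplementedError. A self-contained sketch of the same pattern, using hypothetical TextOnlyDraft/MultimodalDraft classes (not vLLM APIs):

import torch


class TextOnlyDraft:
    # Text-only signature: no multimodal keyword accepted.
    def get_input_embeddings(self, input_ids):
        return torch.zeros(*input_ids.shape, 16)


class MultimodalDraft:
    # Multimodal-capable signature, matching what the probe expects.
    def get_input_embeddings(self, input_ids, multimodal_embeddings=None):
        return torch.zeros(*input_ids.shape, 16)


def draft_supports_mm(model) -> bool:
    # Mirrors the EagleProposer check above: call with a dummy token ID and
    # no multimodal embeddings; any failure means text-only fallback.
    try:
        model.get_input_embeddings(torch.tensor([[1]]),
                                   multimodal_embeddings=None)
        return True
    except (NotImplementedError, AttributeError, TypeError):
        return False


assert draft_supports_mm(MultimodalDraft())
assert not draft_supports_mm(TextOnlyDraft())

One caveat of this pattern: the except clause also catches a TypeError raised inside a draft model that does accept the keyword, silently downgrading it to text-only; the warning log is what makes such a downgrade visible.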