From 93ba7648d0be15cdfcfea3acdc40adbd92bfbad2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonas=20M=2E=20K=C3=BCbler?= <44084297+jmkuebler@users.noreply.github.com>
Date: Sat, 27 Sep 2025 02:10:21 +0200
Subject: [PATCH] [Spec decode] automatically disable mm for text-only draft
 models (#25667)

Signed-off-by: Jonas Kuebler
Signed-off-by: yewentao256
---
 tests/v1/e2e/test_spec_decode.py | 126 ++++++++++++++++---------
 vllm/v1/spec_decode/eagle.py     |  14 ++++
 2 files changed, 78 insertions(+), 62 deletions(-)

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index c4efd7548b81b..ea8d94722859b 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -8,7 +8,7 @@ from typing import Any, Union
 import pytest
 import torch
 
-from tests.utils import get_attn_backend_list_based_on_platform
+from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
@@ -88,69 +88,66 @@ def test_ngram_correctness(
     Compare the outputs of an original LLM and a speculative LLM
     should be the same when using ngram speculative decoding.
     '''
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        test_prompts = get_test_prompts(mm_enabled=False)
+    test_prompts = get_test_prompts(mm_enabled=False)
 
-        ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
+    ref_llm = LLM(model=model_name, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
-        spec_llm = LLM(
-            model=model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
 
-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches >= int(0.66 * len(ref_outputs))
-        del spec_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches >= int(0.66 * len(ref_outputs))
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        False,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        True,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    (("eagle", "eagle618/deepseek-v3-random",
-      "eagle618/eagle-deepseek-v3-random", 1), False),
-],
-                         ids=[
-                             "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
-                             "llama4_eagle", "llama4_eagle_mm",
-                             "deepseek_eagle"
-                         ])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     False,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     True,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
+    ],
+    ids=[
+        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
+        "llama4_eagle_mm", "deepseek_eagle"
+    ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
@@ -174,9 +171,14 @@ def test_eagle_correctness(
     model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_MLA_DISABLE", "1")
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+            # Scout requires default backend selection
+            # because vision encoder has head_dim 88 being incompatible
+            # with FLASH_ATTN and needs to fall back to Flex Attn
+            pass
+        else:
+            m.setenv("VLLM_MLA_DISABLE", "1")
+            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
             pytest.skip("TRITON_ATTN does not support "
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 57da8346f497f..394df48b4153f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -804,6 +804,20 @@ class EagleProposer:
 
         self.attn_layer_names = list(draft_attn_layer_names)
 
+        if self.is_multimodal_model:
+            # Even if the target model is multimodal, we can also use
+            # text-only draft models
+            try:
+                dummy_input_ids = torch.tensor([[1]],
+                                               device=self.input_ids.device)
+                self.model.get_input_embeddings(dummy_input_ids,
+                                                multimodal_embeddings=None)
+            except (NotImplementedError, AttributeError, TypeError):
+                logger.warning(
+                    "Draft model does not support multimodal inputs, "
+                    "falling back to text-only mode")
+                self.is_multimodal_model = False
+
         if supports_multimodal(target_model):
             # handle multimodality
             self.model.config.image_token_index = (