[Spec decode] automatically disable mm for text-only draft models (#25667)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This commit is contained in: parent 4e33a7ea85, commit 6f5c0931c1
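For context before the diff itself: the change lets a multimodal target model be paired with a text-only draft model for EAGLE speculative decoding, with the proposer detecting the mismatch at setup time instead of failing at runtime. The sketch below is illustrative only; the model pairing mirrors the llama4_eagle test case further down, but the exact speculative_config keys and parameter values are assumptions about the vLLM offline API rather than part of this commit.

from vllm import LLM, SamplingParams

# Multimodal target + text-only EAGLE draft. With this change, the proposer
# probes the draft model and, if it cannot take multimodal embeddings, logs a
# warning and falls back to text-only drafting instead of erroring out.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # multimodal target
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",  # text-only draft
        "num_speculative_tokens": 3,
    },
    tensor_parallel_size=4,
    max_model_len=1024,
)
outputs = llm.chat(
    [[{"role": "user", "content": "Summarize speculative decoding in one sentence."}]],
    SamplingParams(temperature=0, max_tokens=64),
)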
@@ -8,7 +8,7 @@ from typing import Any, Union
 import pytest
 import torch
 
-from tests.utils import get_attn_backend_list_based_on_platform
+from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
@@ -88,69 +88,66 @@ def test_ngram_correctness(
     Compare the outputs of an original LLM and a speculative LLM
     should be the same when using ngram speculative decoding.
     '''
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        test_prompts = get_test_prompts(mm_enabled=False)
+    test_prompts = get_test_prompts(mm_enabled=False)
 
-        ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
+    ref_llm = LLM(model=model_name, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
-        spec_llm = LLM(
-            model=model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
 
-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches >= int(0.66 * len(ref_outputs))
-        del spec_llm
-        torch.cuda.empty_cache()
-        cleanup_dist_env_and_memory()
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches >= int(0.66 * len(ref_outputs))
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        False,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        True,
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    (("eagle", "eagle618/deepseek-v3-random",
-      "eagle618/eagle-deepseek-v3-random", 1), False),
-],
-    ids=[
-        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
-        "llama4_eagle", "llama4_eagle_mm",
-        "deepseek_eagle"
-    ])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     False,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+                     True,
+                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
+    ],
+    ids=[
+        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
+        "llama4_eagle_mm", "deepseek_eagle"
+    ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
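In the parametrization above, the two Llama-4-Scout cases are now gated by large_gpu_mark(min_gb=80) imported from tests/utils rather than being skipped unconditionally for CI OOM. The helper's real implementation is not part of this diff; the following is a minimal sketch, under the assumption that it simply turns the memory requirement into a skip mark.

import pytest
import torch


def large_gpu_mark(min_gb: int):
    # Sketch only: skip the test unless the first visible GPU has at least
    # `min_gb` GiB of memory. vLLM's tests/utils.py version may differ.
    if not torch.cuda.is_available():
        return pytest.mark.skip(reason="requires a CUDA GPU")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return pytest.mark.skipif(
        total_gb < min_gb,
        reason=f"requires at least {min_gb} GiB of GPU memory "
               f"(found {total_gb:.0f} GiB)",
    )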
@@ -174,9 +171,14 @@ def test_eagle_correctness(
     model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_MLA_DISABLE", "1")
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+            # Scout requires default backend selection
+            # because vision encoder has head_dim 88 being incompatible
+            # with FLASH_ATTN and needs to fall back to Flex Attn
+            pass
+        else:
+            m.setenv("VLLM_MLA_DISABLE", "1")
+            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
             pytest.skip("TRITON_ATTN does not support "
@@ -804,6 +804,20 @@ class EagleProposer:
 
         self.attn_layer_names = list(draft_attn_layer_names)
 
+        if self.is_multimodal_model:
+            # Even if the target model is multimodal, we can also use
+            # text-only draft models
+            try:
+                dummy_input_ids = torch.tensor([[1]],
+                                               device=self.input_ids.device)
+                self.model.get_input_embeddings(dummy_input_ids,
+                                                multimodal_embeddings=None)
+            except (NotImplementedError, AttributeError, TypeError):
+                logger.warning(
+                    "Draft model does not support multimodal inputs, "
+                    "falling back to text-only mode")
+                self.is_multimodal_model = False
+
         if supports_multimodal(target_model):
             # handle multimodality
             self.model.config.image_token_index = (
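The probe added to EagleProposer above relies on a signature difference: a text-only draft model's get_input_embeddings typically accepts only input_ids, so calling it with a multimodal_embeddings keyword raises TypeError (or NotImplementedError/AttributeError in other implementations), while a multimodal-capable draft accepts the call. A standalone illustration under that assumption follows; the class names are hypothetical stand-ins, not vLLM classes.

import torch
import torch.nn as nn


class TextOnlyDraft(nn.Module):
    # Hypothetical text-only draft: no multimodal_embeddings parameter.
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)

    def get_input_embeddings(self, input_ids):
        return self.embed_tokens(input_ids)


class MultimodalDraft(nn.Module):
    # Hypothetical multimodal-capable draft: accepts multimodal_embeddings.
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)

    def get_input_embeddings(self, input_ids, multimodal_embeddings=None):
        # merging of multimodal embeddings is omitted in this sketch
        return self.embed_tokens(input_ids)


def drafts_multimodally(model: nn.Module) -> bool:
    # Same probe as the diff: a dummy token id and no multimodal embeddings.
    dummy_input_ids = torch.tensor([[1]])
    try:
        model.get_input_embeddings(dummy_input_ids, multimodal_embeddings=None)
    except (NotImplementedError, AttributeError, TypeError):
        return False
    return True


assert drafts_multimodally(MultimodalDraft())
assert not drafts_multimodally(TextOnlyDraft())

Because the probe runs once at proposer setup, a mismatched draft costs one warning and a flipped is_multimodal_model flag rather than a failure on the first multimodal request.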