Use FLASHINFER MLA backend when testing fp8_kv_scale_compile (#28491)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Adrian Abeyta 2025-11-11 18:34:58 -06:00 committed by GitHub
parent 412e153df5
commit d23539549a

@@ -10,6 +10,7 @@ import torch
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -184,13 +185,24 @@ def test_custom_compile_config(
     [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
 )
 @pytest.mark.parametrize(
-    "model",
+    "model, backend",
     [
-        "Qwen/Qwen2-0.5B",  # Standard attention model
-        "deepseek-ai/DeepSeek-V2-Lite",  # MLA (Multi-head Latent Attention) model
+        ("Qwen/Qwen2-0.5B", None),  # Standard attention model
+        (
+            "deepseek-ai/DeepSeek-V2-Lite",
+            AttentionBackendEnum.FLASHINFER_MLA,
+        ),  # MLA (Multi-head Latent Attention) model
     ],
 )
-def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
+def test_fp8_kv_scale_compile(
+    monkeypatch: pytest.MonkeyPatch,
+    compilation_mode: int,
+    model: str,
+    backend: AttentionBackendEnum | None,
+):
+    if backend:
+        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
     model_kwargs = {
         "quantization": "fp8",
         "kv_cache_dtype": "fp8_e4m3",