diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 71f90f6d8d3ee..b4e5e56ac9fe6 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -10,6 +10,7 @@ import torch
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -184,13 +185,24 @@ def test_custom_compile_config(
     [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
 )
 @pytest.mark.parametrize(
-    "model",
+    "model, backend",
     [
-        "Qwen/Qwen2-0.5B",  # Standard attention model
-        "deepseek-ai/DeepSeek-V2-Lite",  # MLA (Multi-head Latent Attention) model
+        ("Qwen/Qwen2-0.5B", None),  # Standard attention model
+        (
+            "deepseek-ai/DeepSeek-V2-Lite",
+            AttentionBackendEnum.FLASHINFER_MLA,
+        ),  # MLA (Multi-head Latent Attention) model
     ],
 )
-def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
+def test_fp8_kv_scale_compile(
+    monkeypatch: pytest.MonkeyPatch,
+    compilation_mode: int,
+    model: str,
+    backend: AttentionBackendEnum | None,
+):
+    if backend:
+        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
+
     model_kwargs = {
         "quantization": "fp8",
         "kv_cache_dtype": "fp8_e4m3",