mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 15:02:22 +08:00
Use FLASHINFER MLA backend when testing fp8_kv_scale_compile (#28491)
Signed-off-by: adabeyta <aabeyta@redhat.com>
This commit is contained in:
parent
412e153df5
commit
d23539549a
@ -10,6 +10,7 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
@ -184,13 +185,24 @@ def test_custom_compile_config(
|
|||||||
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
|
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model, backend",
|
||||||
[
|
[
|
||||||
"Qwen/Qwen2-0.5B", # Standard attention model
|
("Qwen/Qwen2-0.5B", None), # Standard attention model
|
||||||
"deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model
|
(
|
||||||
|
"deepseek-ai/DeepSeek-V2-Lite",
|
||||||
|
AttentionBackendEnum.FLASHINFER_MLA,
|
||||||
|
), # MLA (Multi-head Latent Attention) model
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_fp8_kv_scale_compile(compilation_mode: int, model: str):
|
def test_fp8_kv_scale_compile(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
compilation_mode: int,
|
||||||
|
model: str,
|
||||||
|
backend: AttentionBackendEnum | None,
|
||||||
|
):
|
||||||
|
if backend:
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||||
|
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
"quantization": "fp8",
|
"quantization": "fp8",
|
||||||
"kv_cache_dtype": "fp8_e4m3",
|
"kv_cache_dtype": "fp8_e4m3",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user