From b8089195b450fb66b0336147854bdf2a669fc4f1 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 9 Jun 2025 22:10:44 +0800
Subject: [PATCH] [v1] Add fp32 support to v1 engine through flex attn (#19319)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py
---
 .../attention/test_attention_selector.py | 28 +++++++++++++++++++
 tests/v1/worker/test_gpu_model_runner.py |  6 ++++
 vllm/engine/arg_utils.py                 |  7 -----
 vllm/platforms/cuda.py                   |  4 +++
 4 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index f3e64155703c2..ea5c8013a1319 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -183,6 +183,34 @@ def test_env(
     assert backend.get_name() == expected
 
 
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("use_v1", [True, False])
+def test_fp32_fallback(
+    device: str,
+    use_v1: bool,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test attention backend selection with fp32."""
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
+
+        if device == "cpu":
+            with patch("vllm.attention.selector.current_platform",
+                       CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                           16, False)
+            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                    if use_v1 else "TORCH_SDPA")
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                           16, False)
+            assert (backend.get_name() == "FLEX_ATTENTION"
+                    if use_v1 else "XFORMERS")
+
+
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index caacb1652e9a2..3d51b53df2ce9 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -4,6 +4,7 @@
 import random
 
 import pytest
+import torch
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -399,6 +400,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
 
 
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -427,6 +429,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
 
 
 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -455,6 +458,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
 
 
 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    torch.set_default_dtype(torch.float16)
    layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -483,6 +487,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
 
 
 def test_init_kv_cache_without_kv_sharing():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -550,6 +555,7 @@ def test_init_kv_cache_without_kv_sharing():
 
 
 def test_init_kv_cache_with_kv_sharing_valid():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 4ce1b41e4f87f..837c9b1b1e942 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1337,13 +1337,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # Only Fp16 and Bf16 dtypes since we only support FA.
-        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
-        if model_config.dtype not in V1_SUPPORTED_DTYPES:
-            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
-                               recommend_to_remove=False)
-            return False
-
         # No Embedding Models so far.
         if model_config.task not in ["generate"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 48d1aacba1858..20c06c5c91eee 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -233,6 +233,10 @@ class CudaPlatformBase(Platform):
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
+            if dtype not in (torch.float16, torch.bfloat16):
+                logger.info_once(
+                    f"Using FlexAttention backend for {dtype} on V1 engine.")
+                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             if cls.is_device_capability(100):
                 # Prefer FlashInfer for V1 on Blackwell GPUs if installed
                 try:
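
Editor's usage sketch (not part of the patch): the snippet below mirrors the new
test_fp32_fallback case and assumes a CUDA machine with the V1 engine enabled;
it checks that an fp32 request now resolves to the FlexAttention backend rather
than being rejected by the dtype check removed from arg_utils.py.

    # Hypothetical check; names and signature are taken from the test added above.
    import os
    os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine, as the test does

    import torch
    from vllm.attention.selector import get_attn_backend

    # head_size=16, fp32 model/cache dtype, block_size=16, no MLA.
    backend = get_attn_backend(16, torch.float32, torch.float32, 16, False)
    print(backend.get_name())  # expected on CUDA + V1: "FLEX_ATTENTION"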