diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index ea5c8013a131..f3e64155703c 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -183,34 +183,6 @@ def test_env(
     assert backend.get_name() == expected
 
 
-@pytest.mark.parametrize("device", ["cpu", "cuda"])
-@pytest.mark.parametrize("use_v1", [True, False])
-def test_fp32_fallback(
-    device: str,
-    use_v1: bool,
-    monkeypatch: pytest.MonkeyPatch,
-):
-    """Test attention backend selection with fp32."""
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
-
-        if device == "cpu":
-            with patch("vllm.attention.selector.current_platform",
-                       CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
-            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
-                    if use_v1 else "TORCH_SDPA")
-
-        elif device == "cuda":
-            with patch("vllm.attention.selector.current_platform",
-                       CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
-            assert (backend.get_name() == "FLEX_ATTENTION"
-                    if use_v1 else "XFORMERS")
-
-
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index fa0bab71b954..994432dfd593 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -4,7 +4,6 @@
 import random
 
 import pytest
-import torch
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -400,7 +399,6 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
 
 
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -429,7 +427,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
 
 
 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -458,7 +455,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
 
 
 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -487,7 +483,6 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
 
 
 def test_init_kv_cache_without_kv_sharing():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -555,7 +550,6 @@ def test_init_kv_cache_without_kv_sharing():
 
 
 def test_init_kv_cache_with_kv_sharing_valid():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 837c9b1b1e94..4ce1b41e4f87 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1337,6 +1337,13 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
+        # Only Fp16 and Bf16 dtypes since we only support FA.
+        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
+        if model_config.dtype not in V1_SUPPORTED_DTYPES:
+            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
+                               recommend_to_remove=False)
+            return False
+
         # No Embedding Models so far.
         if model_config.task not in ["generate"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 20c06c5c91ee..48d1aacba185 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -233,10 +233,6 @@ class CudaPlatformBase(Platform):
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
-            if dtype not in (torch.float16, torch.bfloat16):
-                logger.info_once(
-                    f"Using FlexAttenion backend for {dtype} on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             if cls.is_device_capability(100):
                 # Prefer FlashInfer for V1 on Blackwell GPUs if installed
                 try: