Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-15 03:04:56 +08:00
Revert "[v1] Add fp32 support to v1 engine through flex attn" (#19404)
parent 9368cc90b2
commit 5f1ac1e1d1
@@ -183,34 +183,6 @@ def test_env(
         assert backend.get_name() == expected
 
 
-@pytest.mark.parametrize("device", ["cpu", "cuda"])
-@pytest.mark.parametrize("use_v1", [True, False])
-def test_fp32_fallback(
-    device: str,
-    use_v1: bool,
-    monkeypatch: pytest.MonkeyPatch,
-):
-    """Test attention backend selection with fp32."""
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
-
-        if device == "cpu":
-            with patch("vllm.attention.selector.current_platform",
-                       CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
-            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
-                    if use_v1 else "TORCH_SDPA")
-
-        elif device == "cuda":
-            with patch("vllm.attention.selector.current_platform",
-                       CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
-            assert (backend.get_name() == "FLEX_ATTENTION"
-                    if use_v1 else "XFORMERS")
-
-
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
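A quick way to observe the behavior change locally is to adapt the test removed above. The sketch below is hypothetical: the import paths and the get_attn_backend(16, torch.float32, torch.float32, 16, False) call are taken from that test, and the backend names in the comment are only what the deleted assertions implied, not something this diff itself verifies.

import torch
from unittest.mock import patch

# Import paths assumed from the removed test module; not shown in this diff.
from vllm.attention.selector import get_attn_backend
from vllm.platforms.cuda import CudaPlatform

with patch("vllm.attention.selector.current_platform", CudaPlatform()):
    backend = get_attn_backend(16, torch.float32, torch.float32, 16, False)

# Before this revert the deleted test expected "FLEX_ATTENTION" here with
# VLLM_USE_V1=1 and "XFORMERS" on the V0 path; after it, fp32 no longer
# selects FlexAttention on the V1 engine.
print(backend.get_name())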
@@ -4,7 +4,6 @@
 import random
 
 import pytest
-import torch
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -400,7 +399,6 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
 
 
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -429,7 +427,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
 
 
 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -458,7 +455,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
 
 
 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -487,7 +483,6 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
 
 
 def test_init_kv_cache_without_kv_sharing():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -555,7 +550,6 @@ def test_init_kv_cache_without_kv_sharing():
 
 
 def test_init_kv_cache_with_kv_sharing_valid():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -1337,6 +1337,13 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
+        # Only Fp16 and Bf16 dtypes since we only support FA.
+        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
+        if model_config.dtype not in V1_SUPPORTED_DTYPES:
+            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
+                               recommend_to_remove=False)
+            return False
+
         # No Embedding Models so far.
         if model_config.task not in ["generate"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
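The hunk above restores the dtype gate in the EngineArgs V1-support check. A minimal, self-contained sketch of that decision follows; the helper name is illustrative and not part of vLLM, and it assumes the surrounding method simply returns False to trigger the usual V1-to-V0 fallback.

import torch

# Restored gate: only fp16/bf16 stay on the V1 engine once this revert lands.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]

def v1_supports_dtype(dtype: torch.dtype) -> bool:
    # Illustrative helper: any other dtype, e.g. torch.float32, is pushed
    # back off V1 (or raises when V1 was explicitly requested).
    return dtype in V1_SUPPORTED_DTYPES

assert v1_supports_dtype(torch.bfloat16)
assert not v1_supports_dtype(torch.float32)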
@@ -233,10 +233,6 @@ class CudaPlatformBase(Platform):
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
-            if dtype not in (torch.float16, torch.bfloat16):
-                logger.info_once(
-                    f"Using FlexAttenion backend for {dtype} on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             if cls.is_device_capability(100):
                 # Prefer FlashInfer for V1 on Blackwell GPUs if installed
                 try:
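For reference, here is a standalone sketch of the detour this last hunk deletes. The function name and the trailing placeholder return are illustrative, not the real get_attn_backend_cls; only the dtype check and the FlexAttention path come from the removed lines.

import torch

def select_v1_cuda_backend(dtype: torch.dtype) -> str:
    # Pre-revert: non-fp16/bf16 dtypes short-circuited to FlexAttention.
    if dtype not in (torch.float16, torch.bfloat16):
        return ("vllm.v1.attention.backends."
                "flex_attention.FlexAttentionBackend")
    # Post-revert this branch is gone; selection continues with the branches
    # visible in the surrounding context (Triton, FlashInfer on Blackwell, ...).
    return "..."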