[V0 deprecation] Remove _VLLM_V1 suffixes from attention backend names (#25489)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Matthew Bonanni 2025-09-25 13:37:50 -04:00 committed by yewentao256
parent 9659b7e78f
commit a355561291
42 changed files with 131 additions and 174 deletions
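
For downstream users, the visible change is the spelling of VLLM_ATTENTION_BACKEND values; the old suffixed names are still accepted with a deprecation warning (see the selector change below). A minimal usage sketch, illustrative only; the model and prompt are placeholders, not part of this commit:

import os

# Old, now-deprecated spelling carried a redundant suffix:
#   os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
# New spelling after this commit:
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"

from vllm import LLM

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
print(llm.generate("Hello, my name is"))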

View File

@@ -35,7 +35,7 @@ docker run \
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
-VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
 cd tests
 pytest -v -s v1/core
 pytest -v -s v1/engine

View File

@@ -103,7 +103,7 @@ backend_configs = {
 # Triton Attention
 "TritonAttn":
 BackendConfig(name="TritonAttn",
-env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
 comp_config={
 "cudagraph_mode": "FULL",
 }),

View File

@@ -338,7 +338,7 @@ else:
 @pytest.mark.parametrize("model_name, model_class", MODELS)
 @pytest.mark.parametrize("backend",
 [_Backend.FLASHINFER] if current_platform.is_cuda()
-else [_Backend.TRITON_ATTN_VLLM_V1])
+else [_Backend.TRITON_ATTN])
 @pytest.mark.parametrize(
 "split_attention",
 [False, True] if current_platform.is_rocm() else [False])

View File

@@ -68,7 +68,7 @@ def default_server_args(with_tool_parser: bool):
 def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
 default_server_args: list[str]):
 with monkeypatch_module.context() as m:
-m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
 with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
 default_server_args) as remote_server:
 yield remote_server

View File

@@ -31,7 +31,7 @@ DEVICE_MLA_BACKENDS = {
 }
 DEVICE_REGULAR_ATTN_BACKENDS = {
-"cuda": ["XFORMERS", "FLASHINFER"],
+"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
 "hip": ["ROCM_FLASH"],
 "cpu": ["TORCH_SDPA"],
 }
@@ -86,7 +86,7 @@ def test_env(
 with patch("vllm.attention.selector.current_platform",
 CpuPlatform()):
 backend = get_attn_backend(16, torch.float16, None, block_size)
-assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+assert backend.get_name() == "TORCH_SDPA"
 elif device == "hip":
 with patch("vllm.attention.selector.current_platform",
@@ -125,7 +125,7 @@ def test_env(
 None,
 block_size,
 use_mla=use_mla)
-expected = f"{name}_VLLM_V1"
+expected = name
 assert backend.get_name() == expected
 else:
 backend = get_attn_backend(16,
@@ -133,7 +133,7 @@ def test_env(
 None,
 block_size,
 use_mla=use_mla)
-expected = "TRITON_ATTN_VLLM_V1"
+expected = "TRITON_ATTN"
 assert backend.get_name() == expected
 elif device == "cuda":
@@ -160,7 +160,7 @@ def test_env(
 None,
 block_size,
 use_mla=use_mla)
-expected = "CUTLASS_MLA_VLLM_V1"
+expected = "CUTLASS_MLA"
 assert backend.get_name() == expected
 elif name == "FLASHINFER_MLA":
 if block_size not in [32, 64]:
@@ -193,7 +193,7 @@ def test_env(
 None,
 block_size,
 use_mla=use_mla)
-expected = f"{name}_VLLM_V1"
+expected = name
 assert backend.get_name() == expected
 elif name == "FLASH_ATTN_MLA":
 backend = get_attn_backend(16,
@@ -210,7 +210,7 @@ def test_env(
 None,
 block_size,
 use_mla=use_mla)
-expected = "TRITON_MLA_VLLM_V1"
+expected = "TRITON_MLA"
 assert backend.get_name() == expected
 elif name == "FLASHINFER":
 backend = get_attn_backend(16,
@@ -218,25 +218,24 @@ def test_env(
 None,
 block_size,
 use_mla=use_mla)
-expected = "FLASHINFER_VLLM_V1"
+expected = "FLASHINFER"
 assert backend.get_name() == expected
-else:
+elif name == "XFORMERS":
 backend = get_attn_backend(32,
 torch.float16,
 None,
 block_size,
 use_mla=use_mla)
-expected = "FLASH_ATTN_VLLM_V1"
+expected = "XFORMERS"
 assert backend.get_name() == expected
+elif name == "FLASH_ATTN":
-backend = get_attn_backend(16,
+backend = get_attn_backend(32,
 torch.float16,
 None,
 block_size,
 use_mla=use_mla)
-assert backend.get_name() == "FLEX_ATTENTION", (
-"Should fallback to FlexAttention if head size is "
-"not supported by FlashAttention")
+expected = "FLASH_ATTN"
+assert backend.get_name() == expected
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -252,7 +251,7 @@ def test_fp32_fallback(
 with patch("vllm.attention.selector.current_platform",
 CpuPlatform()):
 backend = get_attn_backend(16, torch.float32, None, 16)
-assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+assert backend.get_name() == "TORCH_SDPA"
 elif device == "cuda":
 with patch("vllm.attention.selector.current_platform",
@@ -266,6 +265,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
 # TODO: When testing for v1, pipe in `use_v1` as an argument to
 # get_attn_backend
+pytest.skip("Skipping as current backend selector does not " \
+"handle fallbacks when a backend is set via env var.")
 with monkeypatch.context() as m:
 m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)

View File

@@ -28,7 +28,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
 # Test standard ROCm attention
 backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
 assert (backend.get_name() == "ROCM_FLASH"
-or backend.get_name() == "TRITON_ATTN_VLLM_V1")
+or backend.get_name() == "TRITON_ATTN")
 # MLA test for deepseek related
@@ -40,8 +40,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
 16,
 False,
 use_mla=True)
-assert (backend.get_name() == "TRITON_MLA"
-or backend.get_name() == "TRITON_MLA_VLLM_V1")
+assert backend.get_name() == "TRITON_MLA"
 # If attention backend is None
 # If use_mla is true
@@ -53,8 +52,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
 16,
 False,
 use_mla=True)
-assert (backend.get_name() == "TRITON_MLA"
-or backend.get_name() == "TRITON_MLA_VLLM_V1")
+assert backend.get_name() == "TRITON_MLA"
 # change the attention backend to AITER MLA
 m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
@@ -64,8 +62,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
 1,
 False,
 use_mla=True)
-assert (backend.get_name() == "ROCM_AITER_MLA"
-or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
+assert backend.get_name() == "ROCM_AITER_MLA"
 # If attention backend is None
 # If use_mla is true
@@ -79,5 +76,4 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
 1,
 False,
 use_mla=True)
-assert (backend.get_name() == "ROCM_AITER_MLA"
-or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
+assert backend.get_name() == "ROCM_AITER_MLA"

View File

@@ -524,14 +524,14 @@ def make_backend(backend_name: str) -> AttentionBackend:
 * Backend instance
 '''
-if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"):
+if backend_name == STR_XFORMERS_ATTN_VAL:
 from vllm.v1.attention.backends.xformers import (
 XFormersAttentionBackend)
 return XFormersAttentionBackend()
-if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"):
+if backend_name == STR_FLASH_ATTN_VAL:
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 return FlashAttentionBackend()
-if backend_name == "TRITON_ATTN_VLLM_V1":
+if backend_name == "TRITON_ATTN":
 from vllm.v1.attention.backends.triton_attn import (
 TritonAttentionBackend)
 return TritonAttentionBackend()
@@ -539,7 +539,7 @@ def make_backend(backend_name: str) -> AttentionBackend:
 from vllm.v1.attention.backends.flex_attention import (
 FlexAttentionBackend)
 return FlexAttentionBackend()
-if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"):
+if backend_name == "TORCH_SDPA":
 from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
 return TorchSDPABackend()
 if backend_name == "FLASHINFER":

View File

@@ -84,7 +84,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
 # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
 # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
 # L4 supports FA3.
-m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
 if model_arch == "WhisperForConditionalGeneration":
 m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
 LLM(

View File

@@ -1131,14 +1131,14 @@ def has_module_attribute(module_name, attribute_name):
 def get_attn_backend_list_based_on_platform() -> list[str]:
 if current_platform.is_cuda():
-return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
+return ["FLASH_ATTN", "TRITON_ATTN", "TREE_ATTN"]
 elif current_platform.is_rocm():
-attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
+attn_backend_list = ["TRITON_ATTN"]
 try:
 import aiter # noqa: F401
-attn_backend_list.append("FLASH_ATTN_VLLM_V1")
+attn_backend_list.append("FLASH_ATTN")
 except Exception:
-print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
+print("Skip FLASH_ATTN on ROCm as aiter is not installed")
 return attn_backend_list
 else:

View File

@@ -21,16 +21,15 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 BACKENDS_TO_TEST = [
-_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
-_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN,
-"FLEX_ATTENTION_SLOW"
+_Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION,
+_Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW"
 ]
 # Remove flashinfer from the list if it's not available
 try:
 import flashinfer # noqa: F401
 except ImportError:
-BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1)
+BACKENDS_TO_TEST.remove(_Backend.FLASHINFER)
 def _convert_dtype_to_torch(dtype):
@@ -214,7 +213,7 @@ def run_attention_backend(
 builder_cls, impl_cls = get_attention_backend(actual_backend)
 # Mock flashinfer's get_per_layer_parameters if needed
-if actual_backend == _Backend.FLASHINFER_VLLM_V1:
+if actual_backend == _Backend.FLASHINFER:
 import unittest.mock
 from vllm.v1.attention.backends.utils import PerLayerParameters
@@ -434,7 +433,7 @@ def _test_backend_correctness(
 # [num_blocks, 2, block_size, num_kv_heads, head_size]
 # Select the appropriate KV cache format for each backend
 kv_cache_for_backend = kv_cache
-if backend_name == _Backend.FLASHINFER_VLLM_V1:
+if backend_name == _Backend.FLASHINFER:
 kv_cache_for_backend = kv_cache.transpose(0, 1)
 # For FlashInfer default to HND layout and
@@ -518,8 +517,8 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):
 SLIDING_WINDOW_BACKENDS_TO_TEST = [
-_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
-_Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+_Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN,
+"FLEX_ATTENTION_SLOW"
 ]

View File

@@ -15,8 +15,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 BACKENDS_TO_TEST = [
-_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
-_Backend.TRITON_MLA_VLLM_V1
+_Backend.CUTLASS_MLA, _Backend.FLASHMLA, _Backend.FLASH_ATTN_MLA,
+_Backend.TRITON_MLA
 ]
 # Remove CUTLASS_MLA from the list if not using sm100

View File

@@ -120,30 +120,30 @@ def get_attention_backend(backend_name: _Backend):
 Tuple of (backend_builder_class, backend_impl_class)
 """
 backend_map = {
-_Backend.FLASH_ATTN_VLLM_V1:
+_Backend.FLASH_ATTN:
 ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 if current_platform.is_cuda() else
 "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
 ),
-_Backend.FLASHINFER_VLLM_V1:
+_Backend.FLASHINFER:
 "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
 _Backend.FLEX_ATTENTION:
 "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
-_Backend.TRITON_ATTN_VLLM_V1:
+_Backend.TRITON_ATTN:
 "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
 _Backend.TREE_ATTN:
 "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
-_Backend.XFORMERS_VLLM_V1:
+_Backend.XFORMERS:
 "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
 _Backend.CUTLASS_MLA:
 "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
-_Backend.FLASHMLA_VLLM_V1:
+_Backend.FLASHMLA:
 "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
 _Backend.FLASH_ATTN_MLA:
 "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
 _Backend.FLASHINFER_MLA:
 "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
-_Backend.TRITON_MLA_VLLM_V1:
+_Backend.TRITON_MLA:
 "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
 }

View File

@@ -89,7 +89,7 @@ backend_configs = {
 # Triton Attention
 "TritonAttn":
 BackendConfig(name="TritonAttn",
-env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
 comp_config={
 "cudagraph_mode": "FULL_AND_PIECEWISE",
 }),

View File

@@ -9,11 +9,14 @@ from ...utils import create_new_process_for_each_test
 @create_new_process_for_each_test()
-@pytest.mark.parametrize("attn_backend",
-["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
+@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
 def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
 prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
+if attn_backend == "FLASHINFER":
+pytest.skip("This test is failing with FlashInfer backend and "
+"needs investigation. See issue #25679.")
 with monkeypatch.context() as m:
 m.setenv("VLLM_USE_V1", "1")
 m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

View File

@@ -176,12 +176,11 @@ def test_eagle_correctness(
 m.setenv("VLLM_MLA_DISABLE", "1")
 m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-if (attn_backend == "TRITON_ATTN_VLLM_V1"
-and not current_platform.is_rocm()):
-pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+pytest.skip("TRITON_ATTN does not support "
 "multi-token eagle spec decode on current platform")
-if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
 m.setenv("VLLM_ROCM_USE_AITER", "1")
 method, model_name, spec_model_name, tp_size = model_setup

View File

@@ -314,12 +314,11 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-if (attn_backend == "TRITON_ATTN_VLLM_V1"
-and not current_platform.is_rocm()):
-pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+pytest.skip("TRITON_ATTN does not support "
 "multi-token eagle spec decode on current platform")
-if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
 monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 # Setup draft model mock
@@ -400,16 +399,15 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
 monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-if (attn_backend == "TRITON_ATTN_VLLM_V1"
-and not current_platform.is_rocm()):
-pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+pytest.skip("TRITON_ATTN does not support "
 "multi-token eagle spec decode on current platform")
 if (attn_backend == "TREE_ATTN"):
 pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
 "because it requires special input mocking.")
-if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
 monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 # Use GPU device
@@ -510,12 +508,12 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
 device=device)
 sampling_metadata = mock.MagicMock()
-if attn_backend == "FLASH_ATTN_VLLM_V1":
+if attn_backend == "FLASH_ATTN":
 attn_metadata_builder_cls, _ = get_attention_backend(
-_Backend.FLASH_ATTN_VLLM_V1)
+_Backend.FLASH_ATTN)
-elif attn_backend == "TRITON_ATTN_VLLM_V1":
+elif attn_backend == "TRITON_ATTN":
 attn_metadata_builder_cls, _ = get_attention_backend(
-_Backend.TRITON_ATTN_VLLM_V1)
+_Backend.TRITON_ATTN)
 elif attn_backend == "TREE_ATTN":
 attn_metadata_builder_cls, _ = get_attention_backend(
 _Backend.TREE_ATTN)

View File

@@ -41,12 +41,11 @@ def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
 m.setenv("VLLM_USE_V1", "1")
 m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-if (attn_backend == "TRITON_ATTN_VLLM_V1"
-and not current_platform.is_rocm()):
-pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
+pytest.skip("TRITON_ATTN does not support "
 "multi-token eagle spec decode on current platform")
-if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
 m.setenv("VLLM_ROCM_USE_AITER", "1")
 llm = LLM(

View File

@@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
 block_table=block_table,
 slot_mapping=branch_slot_mapping,
 seqlen_k=sequence_position + q_len,
-backend=_Backend.FLASH_ATTN_VLLM_V1,
+backend=_Backend.FLASH_ATTN,
 ).view(batch_size, -1, num_heads, dim_per_head)
 # Compare the outputs.

View File

@@ -54,26 +54,3 @@ def test_v1_llm_by_default(monkeypatch):
 print(llm.generate("Hello my name is"))
 assert hasattr(llm.llm_engine, "engine_core")
 m.delenv("VLLM_USE_V1")
-def test_v1_attn_backend(monkeypatch):
-with monkeypatch.context() as m:
-if os.getenv("VLLM_USE_V1", None):
-m.delenv("VLLM_USE_V1")
-m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
-# Fall back to V0.
-_ = AsyncEngineArgs(model=MODEL).create_engine_config()
-assert not envs.VLLM_USE_V1
-m.delenv("VLLM_USE_V1")
-# Reject if V1.
-m.setenv("VLLM_USE_V1", "1")
-with pytest.raises(NotImplementedError):
-AsyncEngineArgs(model=MODEL).create_engine_config()
-m.delenv("VLLM_USE_V1")
-m.setenv("VLLM_ATTENTION_BACKEND", "FLASHMLA")
-_ = AsyncEngineArgs(model=MODEL).create_engine_config()
-assert envs.VLLM_USE_V1
-m.delenv("VLLM_USE_V1")

View File

@@ -364,7 +364,7 @@ class Attention(nn.Module, AttentionLayerBase):
 self.impl.process_weights_after_loading(act_dtype)
 # FlashInfer requires attention sinks to be float32
-if (self.backend == _Backend.FLASHINFER_VLLM_V1
+if (self.backend == _Backend.FLASHINFER
 and hasattr(self.impl, 'sinks')):
 from vllm.v1.attention.backends.flashinfer import FlashInferImpl
 assert isinstance(self.impl, FlashInferImpl)
@@ -420,21 +420,17 @@ class MultiHeadAttention(nn.Module):
 self.attn_backend = backend if backend in {
 _Backend.TORCH_SDPA,
-_Backend.TORCH_SDPA_VLLM_V1,
 _Backend.XFORMERS,
-_Backend.PALLAS_VLLM_V1,
+_Backend.PALLAS,
 _Backend.ROCM_AITER_FA,
 _Backend.FLASH_ATTN,
-_Backend.FLASH_ATTN_VLLM_V1,
 } else _Backend.TORCH_SDPA
 if (self.attn_backend == _Backend.XFORMERS
 and not check_xformers_availability()):
 self.attn_backend = _Backend.TORCH_SDPA
-if self.attn_backend in {
-_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
-}:
+if self.attn_backend == _Backend.FLASH_ATTN:
 if use_upstream_fa:
 from flash_attn import flash_attn_varlen_func
 self._flash_attn_varlen_func = flash_attn_varlen_func
@@ -468,11 +464,7 @@ class MultiHeadAttention(nn.Module):
 key = torch.repeat_interleave(key, num_repeat, dim=2)
 value = torch.repeat_interleave(value, num_repeat, dim=2)
-if self.attn_backend in {
-_Backend.FLASH_ATTN,
-_Backend.FLASH_ATTN_VLLM_V1,
-}:
+if self.attn_backend == _Backend.FLASH_ATTN:
 cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
 step=q_len,
 dtype=torch.int32,
@@ -499,8 +491,7 @@ class MultiHeadAttention(nn.Module):
 key,
 value,
 scale=self.scale)
-elif (self.attn_backend == _Backend.TORCH_SDPA
-or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1):
+elif self.attn_backend == _Backend.TORCH_SDPA:
 query, key, value = (x.transpose(1, 2)
 for x in (query, key, value))
 out = F.scaled_dot_product_attention(query,
@@ -508,7 +499,7 @@ class MultiHeadAttention(nn.Module):
 value,
 scale=self.scale)
 out = out.transpose(1, 2)
-elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
+elif self.attn_backend == _Backend.PALLAS:
 query, key, value = (x.transpose(1, 2)
 for x in (query, key, value))
 from torch_xla.experimental.custom_kernel import flash_attention

View File

@@ -186,6 +186,14 @@ def _cached_get_attn_backend(
 # Check the environment variable and override if specified
 backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
 if backend_by_env_var is not None:
+if backend_by_env_var.endswith("_VLLM_V1"):
+logger.warning(
+"The suffix '_VLLM_V1' in the environment variable "
+"%s is no longer necessary as V0 backends have been "
+"deprecated. Please remove this suffix from your "
+"environment variable setting.", STR_BACKEND_ENV_VAR)
+backend_by_env_var = backend_by_env_var.removesuffix(
+"_VLLM_V1")
 selected_backend = backend_name_to_enum(backend_by_env_var)
 if selected_backend is None:
 raise ValueError(
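
The branch added above keeps old environment settings working: a value such as TRITON_ATTN_VLLM_V1 still resolves, but a warning is logged and the suffix is stripped before the name is mapped to a backend enum. A standalone sketch of the same idea (normalize_backend_name and the logger name are illustrative, not vLLM internals):

import logging

logger = logging.getLogger("backend-selector-sketch")

def normalize_backend_name(name: str) -> str:
    """Strip the deprecated '_VLLM_V1' suffix, warning when it is present."""
    if name.endswith("_VLLM_V1"):
        logger.warning("The '_VLLM_V1' suffix is deprecated; use %s instead.",
                       name.removesuffix("_VLLM_V1"))
        name = name.removesuffix("_VLLM_V1")
    return name

# Old spellings are normalized, new spellings pass through unchanged.
assert normalize_backend_name("TRITON_ATTN_VLLM_V1") == "TRITON_ATTN"
assert normalize_backend_name("FLASHINFER") == "FLASHINFER"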

View File

@@ -577,8 +577,8 @@ class NixlConnectorWorker:
 use_mla=self.use_mla)
 self.backend_name = backend.get_name()
 attn_backend = backend_name_to_enum(self.backend_name)
-self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
-self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+self._use_flashinfer = attn_backend == _Backend.FLASHINFER
+self._use_pallas = attn_backend == _Backend.PALLAS
 self.kv_cache_layout = get_kv_cache_layout()
 logger.debug("Detected attention backend %s", self.backend_name)
 logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
@@ -749,7 +749,7 @@ class NixlConnectorWorker:
 # (roughly 8KB vs 5KB).
 # Conversely for FlashInfer, K and V are registered in the same region
 # to better exploit the memory layout (ie num_blocks is the first dim).
-split_k_and_v = not (self.use_mla or self._use_pallas_v1
+split_k_and_v = not (self.use_mla or self._use_pallas
 or self._use_flashinfer)
 tensor_size_bytes = None
 for layer_name, cache_or_caches in xfer_buffers.items():
@@ -938,7 +938,7 @@ class NixlConnectorWorker:
 tp_ratio = divide(self._tp_size[self.engine_id],
 self._tp_size[engine_id])
 assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
-assert not self._use_pallas_v1 or tp_ratio == 1, \
+assert not self._use_pallas or tp_ratio == 1, \
 "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
 # Handle tp_size>num_kv_heads: replicate KV cache.

View File

@@ -1479,25 +1479,21 @@ class EngineArgs:
 "such as ngram, medusa, eagle, or deepseek_mtp.")
 V1_BACKENDS = [
-"FLASH_ATTN_VLLM_V1",
 "FLASH_ATTN",
 "PALLAS",
-"PALLAS_VLLM_V1",
-"TRITON_ATTN_VLLM_V1",
+"TRITON_ATTN",
 "TRITON_MLA",
 "CUTLASS_MLA",
 "FLASHMLA",
-"FLASHMLA_VLLM_V1",
 "FLASH_ATTN_MLA",
 "FLASHINFER",
-"FLASHINFER_VLLM_V1",
 "FLASHINFER_MLA",
 "ROCM_AITER_MLA",
-"TORCH_SDPA_VLLM_V1",
+"TORCH_SDPA",
 "FLEX_ATTENTION",
 "TREE_ATTN",
-"XFORMERS_VLLM_V1",
-"ROCM_ATTN_VLLM_V1",
+"XFORMERS",
+"ROCM_ATTN",
 ]
 if (envs.is_set("VLLM_ATTENTION_BACKEND")
 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):

View File

@@ -42,7 +42,7 @@ def kernel_warmup(worker: "Worker"):
 # and is not a pooling model
 def _is_flashinfer_backend(backend):
 try:
-return backend.get_name() == "FLASHINFER_VLLM_V1"
+return backend.get_name() == "FLASHINFER"
 except NotImplementedError:
 return False

View File

@@ -241,9 +241,8 @@ class CudaPlatformBase(Platform):
 use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
 selected_backend is None and cls.is_device_capability(100)
 and block_size in [32, 64])
-use_flashmla = selected_backend in [
-_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
-] or (selected_backend is None and is_flashmla_supported()[0])
+use_flashmla = selected_backend == _Backend.FLASHMLA or (
+selected_backend is None and is_flashmla_supported()[0])
 use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
 selected_backend is None and flash_attn_supports_mla())
 use_triton = selected_backend == _Backend.TRITON_MLA or (
@@ -282,7 +281,7 @@ class CudaPlatformBase(Platform):
 if use_v1:
 FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
-TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
+TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
 FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
 TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
 XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
@@ -300,16 +299,16 @@ class CudaPlatformBase(Platform):
 elif selected_backend == _Backend.FLEX_ATTENTION:
 logger.info_once("Using FlexAttention backend on V1 engine.")
 return FLEX_ATTENTION_V1
-elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+elif selected_backend == _Backend.TRITON_ATTN:
 logger.info_once("Using Triton backend on V1 engine.")
-return TRITON_ATTN_VLLM_V1
+return TRITON_ATTN
 elif selected_backend == _Backend.FLASH_ATTN:
 logger.info_once("Using Flash Attention backend on V1 engine.")
 return FLASH_ATTN_V1
 elif selected_backend == _Backend.TREE_ATTN:
 logger.info_once("Using Tree Attention backend on V1 engine.")
 return TREE_ATTN_V1
-elif selected_backend == _Backend.XFORMERS_VLLM_V1:
+elif selected_backend == _Backend.XFORMERS:
 logger.info_once("Using XFormers backend on V1 engine.")
 return XFORMERS_V1
@@ -341,7 +340,7 @@ class CudaPlatformBase(Platform):
 if (has_sink or
 use_fp8_kv_cache) and not cls.is_device_capability(90):
 logger.info_once("Using Triton backend on V1 engine.")
-return TRITON_ATTN_VLLM_V1
+return TRITON_ATTN
 elif is_default_backend_supported := is_attn_backend_supported(
 FLASH_ATTN_V1, head_size, dtype,
 allow_import_error=False):
@@ -457,12 +456,12 @@ class CudaPlatformBase(Platform):
 else:
 # Default to FlashAttention
 if attention_backend is None:
-attention_backend = "FLASH_ATTN_VLLM_V1"
+attention_backend = "FLASH_ATTN"
 # All Blackwell backends support fp8
 if cls.is_device_capability(100):
 supported = True
-elif attention_backend == "FLASH_ATTN_VLLM_V1":
+elif attention_backend == "FLASH_ATTN":
 if fp8_attention:
 from vllm.attention.utils.fa_utils import (
 flash_attn_supports_fp8)
@@ -471,7 +470,7 @@ class CudaPlatformBase(Platform):
 supported = True
 elif attention_backend == "FLASHINFER":
 supported = True
-elif attention_backend == "TRITON_ATTN_VLLM_V1":
+elif attention_backend == "TRITON_ATTN":
 supported = cls.supports_fp8()
 return supported

View File

@@ -40,34 +40,26 @@ def in_wsl() -> bool:
 class _Backend(enum.Enum):
 FLASH_ATTN = enum.auto()
-FLASH_ATTN_VLLM_V1 = enum.auto()
-TRITON_ATTN_VLLM_V1 = enum.auto()
+TRITON_ATTN = enum.auto()
 XFORMERS = enum.auto()
 ROCM_FLASH = enum.auto()
 ROCM_AITER_MLA = enum.auto() # Supported by V1
-ROCM_AITER_MLA_VLLM_V1 = enum.auto()
 ROCM_AITER_FA = enum.auto() # used for ViT attn backend
 TORCH_SDPA = enum.auto()
-TORCH_SDPA_VLLM_V1 = enum.auto()
 FLASHINFER = enum.auto()
-FLASHINFER_VLLM_V1 = enum.auto()
 FLASHINFER_MLA = enum.auto()
 TRITON_MLA = enum.auto() # Supported by V1
-TRITON_MLA_VLLM_V1 = enum.auto()
 CUTLASS_MLA = enum.auto()
 FLASHMLA = enum.auto() # Supported by V1
-FLASHMLA_VLLM_V1 = enum.auto()
 FLASH_ATTN_MLA = enum.auto() # Supported by V1
 PALLAS = enum.auto()
-PALLAS_VLLM_V1 = enum.auto()
 IPEX = enum.auto()
 DUAL_CHUNK_FLASH_ATTN = enum.auto()
 DIFFERENTIAL_FLASH_ATTN = enum.auto()
 NO_ATTENTION = enum.auto()
 FLEX_ATTENTION = enum.auto()
 TREE_ATTN = enum.auto()
-XFORMERS_VLLM_V1 = enum.auto()
-ROCM_ATTN_VLLM_V1 = enum.auto()
+ROCM_ATTN = enum.auto()
 class PlatformEnum(enum.Enum):
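
With the suffixed members deleted from the enum, string-to-enum lookups of the old names now fail, which is why callers throughout the tree are updated in this commit. A small illustration using a stand-in enum (a subset of the real _Backend, defined here only for the example):

import enum

class _Backend(enum.Enum):  # stand-in subset of vllm.platforms.interface._Backend
    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    FLASHINFER = enum.auto()

# New names resolve directly by member name.
assert _Backend["TRITON_ATTN"] is _Backend.TRITON_ATTN

# Old suffixed names no longer exist as members.
try:
    _Backend["TRITON_ATTN_VLLM_V1"]
except KeyError:
    print("old suffixed names no longer resolve")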

View File

@@ -218,8 +218,7 @@ class RocmPlatform(Platform):
 raise ValueError(
 f" The selected backend, {selected_backend.name},"
 f"does not support block size {block_size}.")
-if selected_backend in (_Backend.ROCM_AITER_MLA,
-_Backend.ROCM_AITER_MLA_VLLM_V1):
+if selected_backend == _Backend.ROCM_AITER_MLA:
 if block_size == 1:
 logger.info("Using AITER MLA backend on V1 engine.")
 return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
@@ -240,7 +239,7 @@ class RocmPlatform(Platform):
 elif (envs.VLLM_ROCM_USE_AITER and
 envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
 envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
-selected_backend == _Backend.ROCM_ATTN_VLLM_V1:
+selected_backend == _Backend.ROCM_ATTN:
 # rocm specific backend, with aiter and/or
 # triton prefix-prefill
 logger.info("Using Rocm/Aiter Attention backend on V1 engine.")

View File

@@ -50,8 +50,7 @@ class TpuPlatform(Platform):
 dtype: torch.dtype, kv_cache_dtype: Optional[str],
 block_size: int, use_v1: bool, use_mla: bool,
 has_sink) -> str:
-if (selected_backend != _Backend.PALLAS
-and selected_backend != _Backend.PALLAS_VLLM_V1):
+if selected_backend != _Backend.PALLAS:
 logger.info("Cannot use %s backend on TPU.", selected_backend)
 if not use_v1:

View File

@@ -40,14 +40,14 @@ class XPUPlatform(Platform):
 use_v1 = envs.VLLM_USE_V1
 if not use_v1:
 raise ValueError("XPU backend only supports V1.")
-TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
-FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
-if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
+FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
+if selected_backend == _Backend.TRITON_ATTN:
 logger.info_once("Using Triton backend on V1 engine.")
-return TRITON_ATTN_VLLM_V1
+return TRITON_ATTN
 elif selected_backend == _Backend.FLASH_ATTN:
 logger.info_once("Using Flash Attention backend on V1 engine.")
-return FLASH_ATTN_V1
+return FLASH_ATTN
 elif selected_backend:
 raise ValueError(
 f"Invalid attention backend for {cls.device_name}, "
@@ -64,7 +64,7 @@ class XPUPlatform(Platform):
 XPU only support fp8 kv cache with triton backend.
 """
 if envs.is_set("VLLM_ATTENTION_BACKEND") and \
-envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN":
 return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
 return False

View File

@@ -54,7 +54,7 @@ class TorchSDPABackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "TORCH_SDPA_VLLM_V1"
+return "TORCH_SDPA"
 @staticmethod
 def get_impl_cls() -> type["TorchSDPABackendImpl"]:

View File

@@ -60,7 +60,7 @@ class FlashAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "FLASH_ATTN_VLLM_V1"
+return "FLASH_ATTN"
 @staticmethod
 def get_impl_cls() -> type["FlashAttentionImpl"]:

View File

@@ -167,7 +167,7 @@ class FlashInferBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "FLASHINFER_VLLM_V1"
+return "FLASHINFER"
 @staticmethod
 def get_impl_cls() -> type[FlashInferImpl]:

View File

@@ -270,7 +270,7 @@ class MLACommonBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "TRITON_MLA_VLLM_V1"
+return "TRITON_MLA"
 @staticmethod
 def get_metadata_cls() -> type["AttentionMetadata"]:

View File

@@ -27,7 +27,7 @@ class FlashMLABackend(MLACommonBackend):
 @staticmethod
 def get_name() -> str:
-return "FLASHMLA_VLLM_V1"
+return "FLASHMLA"
 @staticmethod
 def get_metadata_cls() -> type["FlashMLAMetadata"]:

View File

@@ -33,7 +33,7 @@ class AiterMLABackend(MLACommonBackend):
 @staticmethod
 def get_name() -> str:
-return "ROCM_AITER_MLA_VLLM_V1"
+return "ROCM_AITER_MLA"
 @staticmethod
 def get_impl_cls() -> type["AiterMLAImpl"]:

View File

@@ -24,7 +24,7 @@ class TritonMLABackend(MLACommonBackend):
 @staticmethod
 def get_name() -> str:
-return "TRITON_MLA_VLLM_V1"
+return "TRITON_MLA"
 @staticmethod
 def get_impl_cls() -> type["TritonMLAImpl"]:

View File

@@ -86,7 +86,7 @@ class PallasAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "PALLAS_VLLM_V1"
+return "PALLAS"
 @staticmethod
 def get_impl_cls() -> type["PallasAttentionBackendImpl"]:

View File

@@ -340,7 +340,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "FLASH_ATTN_VLLM_V1"
+return "FLASH_ATTN"
 @staticmethod
 def get_impl_cls() -> type["AiterFlashAttentionImpl"]:

View File

@@ -159,7 +159,7 @@ class RocmAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "ROCM_ATTN_VLLM_V1"
+return "ROCM_ATTN"
 @staticmethod
 def get_impl_cls() -> type["RocmAttentionImpl"]:

View File

@@ -52,7 +52,7 @@ class TreeAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "TREE_ATTN_VLLM_V1"
+return "TREE_ATTN"
 @staticmethod
 def get_impl_cls() -> type["TreeAttentionImpl"]:

View File

@@ -155,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "TRITON_ATTN_VLLM_V1"
+return "TRITON_ATTN"
 @staticmethod
 def get_impl_cls() -> type["TritonAttentionImpl"]:

View File

@@ -90,7 +90,7 @@ class XFormersAttentionBackend(AttentionBackend):
 @staticmethod
 def get_name() -> str:
-return "XFORMERS_VLLM_V1"
+return "XFORMERS"
 @staticmethod
 def get_impl_cls() -> type["XFormersAttentionImpl"]: