[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni 2025-11-11 06:40:44 -06:00 committed by GitHub
parent 2e78150d24
commit b30dfa03c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
61 changed files with 1338 additions and 1002 deletions

View File

@ -890,11 +890,16 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
- vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/platforms/cuda.py
- vllm/attention/selector.py
commands: commands:
- nvidia-smi - nvidia-smi
- python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/basic/chat.py
# Attention # Attention
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
- pytest -v -s tests/kernels/attention/test_attention_selector.py
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py

View File

@ -10,7 +10,7 @@ from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
@ -104,7 +104,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
# TODO(luka) use get_kv_cache_stride_order # TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend # Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN: if backend == AttentionBackendEnum.ROCM_ATTN:
# k/v as 1st dimention # k/v as 1st dimention
# HND: [num_blocks, num_kv_heads, block_size, head_size] # HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros( kv_cache = torch.zeros(
@ -116,7 +116,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
) )
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention # k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size] # NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros( kv_cache = torch.zeros(
@ -128,7 +128,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
) )
elif backend == _Backend.TRITON_ATTN: elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimention # k/v as 2nd dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size] # NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros( kv_cache = torch.zeros(
@ -140,7 +140,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
) )
elif backend == _Backend.FLASHINFER: elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros( kv_cache = torch.zeros(
num_blocks, num_blocks,
2, 2,
@ -244,8 +244,8 @@ MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = [] MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = [] HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = [] SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = [] BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[_Backend] = [] BACKENDS_FP4: list[AttentionBackendEnum] = []
if current_platform.is_cuda(): if current_platform.is_cuda():
HEADS = [(64, 8), (40, 8)] HEADS = [(64, 8), (40, 8)]
@ -261,8 +261,8 @@ if current_platform.is_cuda():
TestAttentionNvfp4QuantPatternModel, TestAttentionNvfp4QuantPatternModel,
) )
] ]
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER] BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]
elif current_platform.is_rocm(): elif current_platform.is_rocm():
HEADS = [(32, 8), (40, 8)] HEADS = [(32, 8), (40, 8)]
@ -270,9 +270,9 @@ elif current_platform.is_rocm():
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
] ]
BACKENDS = [ BACKENDS = [
_Backend.ROCM_AITER_UNIFIED_ATTN, AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
_Backend.ROCM_ATTN, AttentionBackendEnum.ROCM_ATTN,
_Backend.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
] ]
@ -302,11 +302,11 @@ def test_attention_quant_pattern(
custom_ops: str, custom_ops: str,
model_name: str, model_name: str,
model_class: type[AttentionQuantPatternModel], model_class: type[AttentionQuantPatternModel],
backend: _Backend, backend: AttentionBackendEnum,
dist_init, dist_init,
): ):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
): ):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
@ -314,6 +314,7 @@ def test_attention_quant_pattern(
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42) torch.manual_seed(42)
vllm_config = VllmConfig( vllm_config = VllmConfig(
@ -402,7 +403,7 @@ def test_attention_quant_pattern(
result_fused_1 = model_compiled(q, k, v) result_fused_1 = model_compiled(q, k, v)
if backend == _Backend.FLASHINFER: if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward # With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's # pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded # _o_scale_float, the 2nd round should reuse the loaded

View File

@ -11,7 +11,7 @@ from typing import Any, NamedTuple
import pytest import pytest
import regex as re import regex as re
from tests.v1.attention.utils import _Backend from tests.v1.attention.utils import AttentionBackendEnum
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
class ModelBackendTestCase(NamedTuple): class ModelBackendTestCase(NamedTuple):
model_name: str model_name: str
model_kwargs: dict[str, Any] model_kwargs: dict[str, Any]
backend: _Backend backend: AttentionBackendEnum
attention_fusions: int attention_fusions: int
allreduce_fusions: int | None = None allreduce_fusions: int | None = None
@ -39,14 +39,14 @@ if current_platform.is_cuda():
# Use smaller model for L40s in CI # Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32, attention_fusions=32,
allreduce_fusions=65, allreduce_fusions=65,
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER, backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=48, attention_fusions=48,
allreduce_fusions=96, allreduce_fusions=96,
), ),
@ -56,7 +56,7 @@ if current_platform.is_cuda():
ModelBackendTestCase( ModelBackendTestCase(
model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER, backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=32, attention_fusions=32,
allreduce_fusions=65, allreduce_fusions=65,
), ),
@ -67,7 +67,7 @@ if current_platform.is_cuda():
ModelBackendTestCase( ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct", model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0, attention_fusions=0,
allreduce_fusions=65, allreduce_fusions=65,
), ),
@ -85,19 +85,19 @@ elif current_platform.is_rocm():
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32, attention_fusions=32,
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN, backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32, attention_fusions=32,
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN, backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32, attention_fusions=32,
), ),
] ]
@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def test_attn_quant( def test_attn_quant(
model_name: str, model_name: str,
model_kwargs: dict[str, Any], model_kwargs: dict[str, Any],
backend: _Backend, backend: AttentionBackendEnum,
attention_fusions: int, attention_fusions: int,
allreduce_fusions: int, allreduce_fusions: int,
custom_ops: str, custom_ops: str,
@ -125,7 +125,7 @@ def test_attn_quant(
caplog_mp_spawn, caplog_mp_spawn,
monkeypatch, monkeypatch,
): ):
if backend == _Backend.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
): ):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
def test_tp2_attn_quant_allreduce_rmsnorm( def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str, model_name: str,
model_kwargs: dict, model_kwargs: dict,
backend: _Backend, backend: AttentionBackendEnum,
attention_fusions: int, attention_fusions: int,
allreduce_fusions: int, allreduce_fusions: int,
custom_ops: str, custom_ops: str,

View File

@ -3,13 +3,13 @@
import pytest import pytest
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
def test_mm_encoder_attn_backend_str_conversion(): def test_mm_encoder_attn_backend_str_conversion():
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN
def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_invalid():
@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
def test_mm_encoder_attn_backend_hash_updates(): def test_mm_encoder_attn_backend_hash_updates():
base_hash = MultiModalConfig().compute_hash() base_hash = MultiModalConfig().compute_hash()
overridden_hash = MultiModalConfig( overridden_hash = MultiModalConfig(
mm_encoder_attn_backend=_Backend.FLASH_ATTN mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
).compute_hash() ).compute_hash()
assert base_hash != overridden_hash assert base_hash != overridden_hash

View File

@ -120,12 +120,13 @@ def test_env(
elif device == "cuda": elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla: if use_mla:
# CUDA MLA backend logic: # CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128 # - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only # and Blackwell GPUs (SM 10.x), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs # - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only # (SM 10.x), V1 only
# - FLASHMLA: only supported with block_size == 64 # - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only # - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases # - TRITON_MLA: fallback for other cases
@ -134,58 +135,72 @@ def test_env(
if block_size != 128: if block_size != 128:
# CUTLASS_MLA only supports block_size == 128 # CUTLASS_MLA only supports block_size == 128
pytest.skip("CUTLASS_MLA only supports block_size 128") pytest.skip("CUTLASS_MLA only supports block_size 128")
else: if capability[0] != 10:
backend = get_attn_backend( pytest.skip("CUTLASS MLA is not supported on this platform")
16, torch.float16, None, block_size, use_mla=use_mla backend = get_attn_backend(
) 576, torch.float16, None, block_size, use_mla=use_mla
expected = "CUTLASS_MLA" )
assert backend.get_name() == expected expected = "CUTLASS_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER_MLA": elif name == "FLASHINFER_MLA":
if capability[0] != 10:
pytest.skip(
"FlashInfer MLA is not supported on this platform"
)
if block_size not in [32, 64]: if block_size not in [32, 64]:
# FlashInfer MLA only supports block_size 32 or 64 # FlashInfer MLA only supports block_size 32 or 64
pytest.skip( pytest.skip(
"FlashInfer MLA only supports block_size 32 or 64" "FlashInfer MLA only supports block_size 32 or 64"
) )
else: backend = get_attn_backend(
backend = get_attn_backend( 576, torch.float16, None, block_size, use_mla=use_mla
16, torch.float16, None, block_size, use_mla=use_mla )
) expected = "FLASHINFER_MLA"
expected = "FLASHINFER_MLA" assert backend.get_name() == expected
assert backend.get_name() == expected
elif name == "FLASHMLA": elif name == "FLASHMLA":
if block_size != 64: if block_size != 64:
# FlashMLA only supports block_size == 64 # FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64") pytest.skip("FlashMLA only supports block_size 64")
else: from vllm.v1.attention.backends.mla.flashmla import (
from vllm.v1.attention.backends.mla.flashmla import ( is_flashmla_dense_supported,
is_flashmla_dense_supported, )
)
is_supported, _ = is_flashmla_dense_supported() is_supported, _ = is_flashmla_dense_supported()
if not is_supported: if not is_supported:
pytest.skip("FlashMLA not supported on this platform") pytest.skip("FlashMLA not supported on this platform")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576,
torch.float16,
None,
block_size,
use_mla=use_mla,
)
expected = name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
from vllm.attention.utils.fa_utils import (
flash_attn_supports_mla,
)
if not flash_attn_supports_mla():
pytest.skip(
"FlashAttention MLA not supported on this platform"
)
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "FLASH_ATTN_MLA" expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
# TRITON_MLA or other fallback # TRITON_MLA or other fallback
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "TRITON_MLA" expected = "TRITON_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASHINFER": elif name == "FLASHINFER":
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 64, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "FLASHINFER" expected = "FLASHINFER"
assert backend.get_name() == expected assert backend.get_name() == expected

View File

@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip": elif device == "hip":
with ( with (
patch("vllm.attention.layer.current_platform", RocmPlatform()), patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32) # Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available # - with upstream FA not available
@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str):
), ),
): ):
attn = MultiHeadAttention(16, 72, scale=1) attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS assert attn.attn_backend == AttentionBackendEnum.XFORMERS
# Test CUDA with head_size=72 (not divisible by 32) # Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available # - with upstream FA available
@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str):
), ),
): ):
attn = MultiHeadAttention(16, 72, scale=1) attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
def ref_attention( def ref_attention(

View File

@ -93,6 +93,17 @@ def can_initialize(
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
) )
if model_arch == "DeepseekV32ForCausalLM":
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability and capability.major < 9:
pytest.skip(
f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) "
f"for FLASHMLA_SPARSE backend. Current device has compute "
f"capability {capability.major}.{capability.minor}"
)
with ( with (
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
monkeypatch.context() as m, monkeypatch.context() as m,

View File

@ -15,7 +15,7 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
@ -27,11 +27,11 @@ from vllm.v1.attention.backends.utils import (
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.FLASHINFER, AttentionBackendEnum.FLASHINFER,
_Backend.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
_Backend.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
_Backend.TREE_ATTN, AttentionBackendEnum.TREE_ATTN,
"FLEX_ATTENTION_SLOW", "FLEX_ATTENTION_SLOW",
] ]
@ -39,7 +39,7 @@ BACKENDS_TO_TEST = [
try: try:
import flashinfer # noqa: F401 import flashinfer # noqa: F401
except ImportError: except ImportError:
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER)
def _convert_dtype_to_torch(dtype): def _convert_dtype_to_torch(dtype):
@ -192,7 +192,7 @@ class MockAttentionLayer:
def run_attention_backend( def run_attention_backend(
backend: _Backend, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
@ -211,13 +211,13 @@ def run_attention_backend(
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
if backend == "FLEX_ATTENTION_SLOW": if backend == "FLEX_ATTENTION_SLOW":
actual_backend = _Backend.FLEX_ATTENTION actual_backend = AttentionBackendEnum.FLEX_ATTENTION
use_direct_block_mask = False use_direct_block_mask = False
builder_cls, impl_cls = try_get_attention_backend(actual_backend) builder_cls, impl_cls = try_get_attention_backend(actual_backend)
# Mock flashinfer's get_per_layer_parameters if needed # Mock flashinfer's get_per_layer_parameters if needed
if actual_backend == _Backend.FLASHINFER: if actual_backend == AttentionBackendEnum.FLASHINFER:
import unittest.mock import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters from vllm.v1.attention.backends.utils import PerLayerParameters
@ -246,7 +246,7 @@ def run_attention_backend(
else: else:
# Build metadata # Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
if actual_backend == _Backend.FLEX_ATTENTION: if actual_backend == AttentionBackendEnum.FLEX_ATTENTION:
builder.direct_build = use_direct_block_mask builder.direct_build = use_direct_block_mask
attn_metadata = builder.build( attn_metadata = builder.build(
common_prefix_len=0, common_prefix_len=0,
@ -289,7 +289,7 @@ def run_attention_backend(
def _test_backend_correctness( def _test_backend_correctness(
batch_spec: BatchSpec, batch_spec: BatchSpec,
model: str, model: str,
backend_to_test: list[_Backend | str], backend_to_test: list[AttentionBackendEnum | str],
mask_mod, mask_mod,
*, *,
block_size: int = 16, block_size: int = 16,
@ -455,17 +455,20 @@ def _test_backend_correctness(
# Select the appropriate KV cache format for each backend # Select the appropriate KV cache format for each backend
kv_cache_for_backend = kv_cache kv_cache_for_backend = kv_cache
reset_kv_cache_layout = False reset_kv_cache_layout = False
if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): if backend_name in (
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
):
kv_cache_for_backend = kv_cache.transpose(0, 1) kv_cache_for_backend = kv_cache.transpose(0, 1)
if backend_name == _Backend.FLASHINFER: if backend_name == AttentionBackendEnum.FLASHINFER:
# For FlashInfer default to HND layout and # For FlashInfer default to HND layout and
kv_cache_for_backend = ( kv_cache_for_backend = (
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
) )
set_kv_cache_layout("HND") set_kv_cache_layout("HND")
reset_kv_cache_layout = True reset_kv_cache_layout = True
elif backend_name == _Backend.TRITON_ATTN: elif backend_name == AttentionBackendEnum.TRITON_ATTN:
kv_cache_for_backend = kv_cache_for_backend.contiguous() kv_cache_for_backend = kv_cache_for_backend.contiguous()
try: try:
@ -547,7 +550,9 @@ def test_causal_backend_correctness(
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = ( LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] [AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
) )
SMALL_BLOCK_BACKENDS = [ SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
@ -573,9 +578,9 @@ def test_causal_backend_correctness(
SLIDING_WINDOW_BACKENDS_TO_TEST = [ SLIDING_WINDOW_BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
_Backend.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW", "FLEX_ATTENTION_SLOW",
] ]
@ -612,7 +617,9 @@ def test_sliding_window_backend_correctness(
) )
LARGE_BLOCK_BACKENDS = ( LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] [AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
) )
SMALL_BLOCK_BACKENDS = [ SMALL_BLOCK_BACKENDS = [
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS

View File

@ -18,12 +18,11 @@ from tests.v1.attention.utils import (
try_get_attention_backend, try_get_attention_backend,
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend, backend_to_class_str from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
@ -31,25 +30,25 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
_Backend.FLASHMLA, AttentionBackendEnum.FLASHMLA,
_Backend.FLASH_ATTN_MLA, AttentionBackendEnum.FLASH_ATTN_MLA,
_Backend.FLASHINFER_MLA, AttentionBackendEnum.FLASHINFER_MLA,
_Backend.TRITON_MLA, AttentionBackendEnum.TRITON_MLA,
] ]
# Remove sm100 backends from the list if not using sm100 # Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported # Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla(): if not flash_attn_supports_mla():
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA)
# Remove FLASHMLA from the list if not supported # Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]: if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
SPEC_DECODE_BACKENDS = [] SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST: for backend in BACKENDS_TO_TEST:
@ -62,9 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {} BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST: for backend in BACKENDS_TO_TEST:
backend_class_str = backend_to_class_str(backend) supported_sizes = backend.get_class().supported_kernel_block_sizes
backend_class = resolve_obj_by_qualname(backend_class_str)
supported_sizes = backend_class.get_supported_kernel_block_size()
if supported_sizes: if supported_sizes:
default_size = supported_sizes[0] default_size = supported_sizes[0]
block_size = ( block_size = (
@ -291,7 +288,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend( def run_attention_backend(
backend: _Backend, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
@ -813,7 +810,7 @@ def test_backend_correctness(
# Create a summary for the single-line failure message # Create a summary for the single-line failure message
backend_names = [] backend_names = []
for f in failures: for f in failures:
if "[_Backend." in f: if "[AttentionBackendEnum." in f:
backend_name = f.split("[")[1].split("]")[0] backend_name = f.split("[")[1].split("]")[0]
backend_names.append(backend_name) backend_names.append(backend_name)

View File

@ -8,7 +8,7 @@ import pytest
import torch import torch
from vllm.attention.backends.abstract import AttentionImpl from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import _Backend, backend_to_class_str from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
CompilationConfig, CompilationConfig,
@ -20,7 +20,6 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.config.model import ModelDType from vllm.config.model import ModelDType
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
@ -120,15 +119,14 @@ def create_common_attn_metadata(
def try_get_attention_backend( def try_get_attention_backend(
backend: _Backend, backend: AttentionBackendEnum,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
"""Try to get the attention backend class, skipping test if not found.""" """Try to get the attention backend class, skipping test if not found."""
backend_class_str = backend_to_class_str(backend)
try: try:
backend_class = resolve_obj_by_qualname(backend_class_str) backend_class = backend.get_class()
return backend_class.get_builder_cls(), backend_class.get_impl_cls() return backend_class.get_builder_cls(), backend_class.get_impl_cls()
except ImportError as e: except ImportError as e:
pytest.skip(f"{backend_class_str} not available: {e}") pytest.skip(f"{backend.name} not available: {e}")
raise AssertionError("unreachable") from None raise AssertionError("unreachable") from None

View File

@ -13,7 +13,7 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock() sampling_metadata = mock.MagicMock()
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
elif attn_backend == "TRITON_ATTN": elif attn_backend == "TRITON_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TRITON_ATTN
)
elif attn_backend == "TREE_ATTN": elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
else: else:
raise ValueError(f"Unsupported attention backend: {attn_backend}") raise ValueError(f"Unsupported attention backend: {attn_backend}")
@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree):
proposer.attn_layer_names = ["layer.0"] proposer.attn_layer_names = ["layer.0"]
# Get the tree attention metadata builder. # Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names, layer_names=proposer.attn_layer_names,

View File

@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock() sampling_metadata = mock.MagicMock()
# Setup attention metadata # Setup attention metadata
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),

View File

@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@ -35,7 +35,7 @@ def forward_attention(
block_table: torch.Tensor, block_table: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
seqlen_k: int, seqlen_k: int,
backend: _Backend, backend: AttentionBackendEnum,
spec_token_tree: str | None = None, spec_token_tree: str | None = None,
num_spec_tokens: int = 0, num_spec_tokens: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None:
block_table=block_table, block_table=block_table,
slot_mapping=tree_slot_mapping, slot_mapping=tree_slot_mapping,
seqlen_k=seqlen_k, seqlen_k=seqlen_k,
backend=_Backend.TREE_ATTN, backend=AttentionBackendEnum.TREE_ATTN,
spec_token_tree=spec_token_tree, spec_token_tree=spec_token_tree,
num_spec_tokens=tree_size_q - 1, num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head) ).view(batch_size, -1, num_heads, dim_per_head)
@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
block_table=block_table, block_table=block_table,
slot_mapping=branch_slot_mapping, slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len, seqlen_k=sequence_position + q_len,
backend=_Backend.FLASH_ATTN, backend=AttentionBackendEnum.FLASH_ATTN,
).view(batch_size, -1, num_heads, dim_per_head) ).view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs. # Compare the outputs.

View File

@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf], supported_sizes: list[int | MultipleOf],
): ):
class _MockBackend: class _MockBackend:
@staticmethod supported_kernel_block_sizes = supported_sizes
def get_supported_kernel_block_size():
return supported_sizes
return _MockBackend() return _MockBackend()
@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks if GPUModelRunner initializes correctly when an attention # This test checks if GPUModelRunner initializes correctly when an attention
# backend enforces a non-default KV cache stride order. # backend enforces a non-default KV cache stride order.
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
expected_kv_cache_shape = [ head_size = model_runner.model_config.get_head_size()
2,
NUM_BLOCKS, # Get the expected shape from the backend's get_kv_cache_shape method
BLOCK_SIZE, # to ensure compatibility with different backends (triton vs flexattention)
n_heads, attn_backend = None
model_runner.model_config.get_head_size(), for attn_group in model_runner._attn_group_iterator():
] attn_backend = attn_group.backend
break
assert attn_backend is not None, "No attention backend found"
expected_kv_cache_shape = list(
attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size)
)
# TODO mla test # TODO mla test
default_stride = tuple(range(5)) default_stride = tuple(range(5))
# Permutation that gets you back to expected kv shape # Permutation that gets you back to expected kv shape

View File

@ -2,13 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, Protocol, TypeVar from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
class AttentionType: class AttentionType:
""" """
@ -40,6 +45,9 @@ class AttentionBackend(ABC):
# calling the custom op. When piecewise cudagraph is enabled, this # calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph. # makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
@staticmethod @staticmethod
@abstractmethod @abstractmethod
@ -51,10 +59,6 @@ class AttentionBackend(ABC):
def get_impl_cls() -> type["AttentionImpl"]: def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return cls.get_impl_cls().get_supported_kernel_block_size()
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
@ -79,6 +83,136 @@ class AttentionBackend(ABC):
def full_cls_name(cls) -> tuple[str, str]: def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__) return (cls.__module__, cls.__qualname__)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
return dtype in cls.supported_dtypes
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
if kv_cache_dtype is None:
return True
return (not cls.supported_kv_cache_dtypes) or (
kv_cache_dtype in cls.supported_kv_cache_dtypes
)
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None:
return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
if not cls.supported_kernel_block_sizes:
return True
for supported_size in cls.supported_kernel_block_sizes:
is_multiple_of = (
isinstance(supported_size, MultipleOf)
and block_size % supported_size.base == 0
)
is_int_equal = (
isinstance(supported_size, int) and block_size == supported_size
)
if is_multiple_of or is_int_equal:
return True
return False
@classmethod
def is_mla(cls) -> bool:
return False
@classmethod
def supports_sink(cls) -> bool:
return False
@classmethod
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> str | None:
return None
@classmethod
def validate_configuration(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
invalid_reasons.append("head_size not supported")
if not cls.supports_dtype(dtype):
invalid_reasons.append("dtype not supported")
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")
else:
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
if combination_reason is not None:
invalid_reasons.append(combination_reason)
return invalid_reasons
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return None
class AttentionMetadata: class AttentionMetadata:
pass pass
@ -151,11 +285,6 @@ class AttentionImpl(ABC, Generic[T]):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# TODO: implement this function for all backends.
return [MultipleOf(1)]
@abstractmethod @abstractmethod
def forward( def forward(
self, self,

View File

@ -3,108 +3,192 @@
"""Attention backend registry""" """Attention backend registry"""
import enum import enum
from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class _Backend(enum.Enum): logger = init_logger(__name__)
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto()
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()
FLASHMLA_SPARSE = enum.auto()
FLASH_ATTN_MLA = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()
BACKEND_MAP = { class _AttentionBackendEnumMeta(enum.EnumMeta):
_Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 """Metaclass for AttentionBackendEnum to provide better error messages."""
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 def __getitem__(cls, name: str):
_Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501 """Get backend by name with helpful error messages."""
_Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501 try:
_Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501 return super().__getitem__(name)
_Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501 except KeyError:
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501 members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 valid_backends = ", ".join(m.name for m in members)
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 raise ValueError(
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 f"Unknown attention backend: '{name}'. "
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501 f"Valid options are: {valid_backends}"
_Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501 ) from None
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
_Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501
}
def register_attn_backend(backend: _Backend, class_path: str | None = None): class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported attention backends.
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
""" """
Decorator: register a custom attention backend into BACKEND_MAPPING.
- If class_path is provided, use it.
- Otherwise, auto-generate from the class object.
Validation: only checks if 'backend' is a valid _Backend enum member.
Overwriting existing mappings is allowed. This enables other hardware
platforms to plug in custom out-of-tree backends.
"""
if not isinstance(backend, _Backend):
raise ValueError(f"{backend} is not a valid _Backend enum value.")
def decorator(cls): FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
path = class_path or f"{cls.__module__}.{cls.__qualname__}" TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
BACKEND_MAP[backend] = path XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHMLA_SPARSE = (
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
)
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
ROCM_AITER_UNIFIED_ATTN = (
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If Backend.CUSTOM is used without being registered
"""
path = _OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If Backend.CUSTOM is used without being registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_OVERRIDES.pop(self, None)
_OVERRIDES: dict[AttentionBackendEnum, str] = {}
def register_backend(
backend: AttentionBackendEnum, class_path: str | None = None
) -> Callable[[type], type]:
"""Register or override a backend implementation.
Args:
backend: The AttentionBackendEnum member to register
class_path: Optional class path. If not provided and used as
decorator, will be auto-generated from the class.
Returns:
Decorator function if class_path is None, otherwise a no-op
Examples:
# Override an existing backend
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Register a custom third-party backend
@register_backend(AttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
# Direct registration
register_backend(
AttentionBackendEnum.CUSTOM,
"my.module.MyCustomBackend"
)
"""
def decorator(cls: type) -> type:
_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
return cls return cls
if class_path is not None:
_OVERRIDES[backend] = class_path
return lambda x: x
return decorator return decorator
def backend_to_class_str(backend: _Backend) -> str: # Backwards compatibility alias for plugins
"""Get the backend class string class _BackendMeta(type):
"""Metaclass to provide deprecation warnings when accessing _Backend."""
Args: def __getattribute__(cls, name: str):
backend: The backend enum value if name not in ("__class__", "__mro__", "__name__"):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return getattr(AttentionBackendEnum, name)
Returns: def __getitem__(cls, name: str):
The backend class string logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return AttentionBackendEnum[name]
class _Backend(metaclass=_BackendMeta):
"""Deprecated: Use AttentionBackendEnum instead.
This class is provided for backwards compatibility with plugins
and will be removed in a future release.
""" """
return BACKEND_MAP[backend]
pass
def backend_to_class(backend: _Backend) -> type:
"""Get the backend class.
Args:
backend: The backend enum value
Returns:
The backend class
"""
backend_class_name = backend_to_class_str(backend)
return resolve_obj_by_qualname(backend_class_name)
def backend_name_to_enum(backend_name: str) -> _Backend | None:
"""
Convert a string backend name to a _Backend enum value.
Returns:
_Backend: enum value if backend_name is a valid in-tree type
None: otherwise it's an invalid in-tree type or an out-of-tree platform
is loaded.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None

View File

@ -12,7 +12,7 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype):
def maybe_get_vit_flash_attn_backend( def maybe_get_vit_flash_attn_backend(
attn_backend: _Backend, attn_backend: AttentionBackendEnum,
use_upstream_fa: bool, use_upstream_fa: bool,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[_Backend, Callable | None]: ) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm(): if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = _Backend.ROCM_AITER_FA attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif ( elif (
check_upstream_fa_availability(torch.get_default_dtype()) check_upstream_fa_availability(torch.get_default_dtype())
and on_gfx9() and on_gfx9()
and attn_backend_override is None and attn_backend_override is None
): ):
attn_backend = _Backend.FLASH_ATTN attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
else: else:
return _Backend.TORCH_SDPA, None return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda(): elif current_platform.is_cuda():
if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
attn_backend = _Backend.FLASH_ATTN attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
elif current_platform.is_xpu(): elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, ( assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend." "XPU platform only supports FLASH_ATTN as vision attention backend."
) )
use_upstream_fa = False use_upstream_fa = False
else: else:
return _Backend.TORCH_SDPA, None return AttentionBackendEnum.TORCH_SDPA, None
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if attn_backend in {
if attn_backend == _Backend.ROCM_AITER_FA: AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
if use_upstream_fa: if use_upstream_fa:
@ -309,7 +313,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
**extra_impl_args, **extra_impl_args,
) )
self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@ -530,13 +534,13 @@ class MultiHeadAttention(nn.Module):
backend backend
if backend if backend
in { in {
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.PALLAS, AttentionBackendEnum.PALLAS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
} }
else _Backend.TORCH_SDPA else AttentionBackendEnum.TORCH_SDPA
) )
self.attn_backend, self._flash_attn_varlen_func = ( self.attn_backend, self._flash_attn_varlen_func = (
@ -547,17 +551,23 @@ class MultiHeadAttention(nn.Module):
) )
) )
if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): if (
self.attn_backend = _Backend.TORCH_SDPA self.attn_backend == AttentionBackendEnum.XFORMERS
and not check_xformers_availability()
):
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
# this condition is just to make sure that the # this condition is just to make sure that the
# use_upstream_fa in the log is correct # use_upstream_fa in the log is correct
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
use_upstream_fa = True use_upstream_fa = True
logger.info_once( logger.info_once(
@ -606,17 +616,17 @@ class MultiHeadAttention(nn.Module):
max_seqlen_k=kv_len, max_seqlen_k=kv_len,
softmax_scale=self.scale, softmax_scale=self.scale,
) )
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query, key, value, scale=self.scale query, key, value, scale=self.scale
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
out = out.transpose(1, 2) out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS: elif self.attn_backend == AttentionBackendEnum.PALLAS:
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention from torch_xla.experimental.custom_kernel import flash_attention

View File

@ -4,14 +4,15 @@
import os import os
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache from functools import cache
from typing import cast, get_args
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
def get_env_variable_attn_backend() -> _Backend | None: def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
""" """
Get the backend override specified by the vLLM attention Get the backend override specified by the vLLM attention
backend environment variable, if one is specified. backend environment variable, if one is specified.
Returns: Returns:
* _Backend enum value if an override is specified * AttentionBackendEnum value if an override is specified
* None otherwise * None otherwise
""" """
backend_name = os.environ.get(STR_BACKEND_ENV_VAR) backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else backend_name_to_enum(backend_name) return None if backend_name is None else AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend # Global state allows a particular choice of backend
@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None:
# #
# THIS SELECTION TAKES PRECEDENCE OVER THE # THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: _Backend | None = None forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: _Backend | None) -> None: def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
""" """
Force all attention operations to use a specified backend. Force all attention operations to use a specified backend.
@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None:
forced_attn_backend = attn_backend forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> _Backend | None: def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
""" """
Get the currently-forced choice of attention backend, Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled. or None if auto-selection is currently enabled.
@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None:
return forced_attn_backend return forced_attn_backend
@dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
def is_attn_backend_supported(
attn_backend: str | type[AttentionBackend],
head_size: int,
dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
if not allow_import_error:
raise
return _IsSupported(can_import=False, head_size=False, dtype=False)
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(
attn_backend, "get_supported_head_sizes", None
):
is_head_size_supported = head_size in get_supported_head_sizes()
elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
is_head_size_supported = True
except Exception:
is_head_size_supported = False
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support head size validation"
)
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support dtype validation"
)
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
def get_attn_backend( def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
block_size: int, block_size: int | None,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
return _cached_get_attn_backend( return _cached_get_attn_backend(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size, block_size=block_size,
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink, has_sink=has_sink,
@ -149,8 +100,8 @@ def get_attn_backend(
def _cached_get_attn_backend( def _cached_get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
@ -161,7 +112,9 @@ def _cached_get_attn_backend(
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE. # ENVIRONMENT VARIABLE.
selected_backend = None selected_backend = None
backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None: if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting selected_backend = backend_by_global_setting
else: else:
@ -177,12 +130,13 @@ def _cached_get_attn_backend(
STR_BACKEND_ENV_VAR, STR_BACKEND_ENV_VAR,
) )
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
selected_backend = backend_name_to_enum(backend_by_env_var) try:
if selected_backend is None: selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError( raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. " f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"Valid backends are: {list(_Backend.__members__.keys())}" f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) ) from e
# get device-specific attn_backend # get device-specific attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -202,12 +156,26 @@ def _cached_get_attn_backend(
raise ValueError( raise ValueError(
f"Invalid attention backend for {current_platform.device_name}" f"Invalid attention backend for {current_platform.device_name}"
) )
return resolve_obj_by_qualname(attention_cls) backend = resolve_obj_by_qualname(attention_cls)
# Adjust kv cache layout if the selected backend requires a specific one
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout(required_layout)
logger.info(
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
return backend
@contextmanager @contextmanager
def global_force_attn_backend_context_manager( def global_force_attn_backend_context_manager(
attn_backend: _Backend, attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
""" """
Globally force a vLLM attention backend override within a Globally force a vLLM attention backend override within a

View File

@ -21,7 +21,15 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] CacheDType = Literal[
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"fp8_inc",
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"] KVOffloadingBackend = Literal["native", "lmcache"]

View File

@ -45,7 +45,7 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
@ -53,7 +53,7 @@ if TYPE_CHECKING:
else: else:
PretrainedConfig = Any PretrainedConfig = Any
_Backend = Any AttentionBackendEnum = Any
me_quant = LazyLoader( me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization" "model_executor", globals(), "vllm.model_executor.layers.quantization"
) )
@ -302,7 +302,7 @@ class ModelConfig:
mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[_Backend | str | None] = None mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
interleave_mm_strings: InitVar[bool | None] = None interleave_mm_strings: InitVar[bool | None] = None
skip_mm_profiling: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None video_pruning_rate: InitVar[float | None] = None
@ -420,7 +420,7 @@ class ModelConfig:
mm_processor_cache_type: MMCacheType | None, mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None, mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_tp_mode: MMEncoderTPMode | None, mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: _Backend | str | None, mm_encoder_attn_backend: AttentionBackendEnum | str | None,
interleave_mm_strings: bool | None, interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None, skip_mm_profiling: bool | None,
video_pruning_rate: float | None, video_pruning_rate: float | None,

View File

@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
else: else:
_Backend = Any AttentionBackendEnum = Any
@dataclass @dataclass
@ -125,10 +125,10 @@ class MultiModalConfig:
DP (which is controlled by `--data-parallel-size`). DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.""" `"weights"` if the encoder does not support DP."""
mm_encoder_attn_backend: _Backend | None = None mm_encoder_attn_backend: AttentionBackendEnum | None = None
"""Optional override for the multi-modal encoder attention backend when """Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from using vision transformers. Accepts any value from
`vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using """Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string.""" --chat-template-content-format=string."""
@ -167,26 +167,16 @@ class MultiModalConfig:
@field_validator("mm_encoder_attn_backend", mode="before") @field_validator("mm_encoder_attn_backend", mode="before")
@classmethod @classmethod
def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: def _validate_mm_encoder_attn_backend(
from vllm.attention.backends.registry import ( cls, value: str | AttentionBackendEnum | None
_Backend as BackendEnum, ) -> AttentionBackendEnum | None:
) if value is None or isinstance(value, AttentionBackendEnum):
from vllm.attention.backends.registry import (
backend_name_to_enum,
)
if value is None or isinstance(value, BackendEnum):
return value return value
if isinstance(value, str): assert isinstance(value, str), (
candidate = backend_name_to_enum(value.upper()) "mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
if candidate is not None:
return candidate
valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
raise ValueError(
f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
) )
return AttentionBackendEnum[value.upper()]
@model_validator(mode="after") @model_validator(mode="after")
def _validate_multimodal_config(self): def _validate_multimodal_config(self):

View File

@ -21,7 +21,7 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -876,9 +876,9 @@ class NixlConnectorWorker:
use_mla=self.use_mla, use_mla=self.use_mla,
) )
self.backend_name = backend.get_name() self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name) attn_backend = AttentionBackendEnum[self.backend_name]
self._use_flashinfer = attn_backend == _Backend.FLASHINFER self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
self._use_pallas = attn_backend == _Backend.PALLAS self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)

View File

@ -32,7 +32,7 @@ from pydantic.fields import FieldInfo
from typing_extensions import TypeIs, deprecated from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
CompilationConfig, CompilationConfig,
@ -462,7 +462,7 @@ class EngineArgs:
MultiModalConfig.mm_shm_cache_max_object_size_mb MultiModalConfig.mm_shm_cache_max_object_size_mb
) )
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: _Backend | str | None = ( mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
MultiModalConfig.mm_encoder_attn_backend MultiModalConfig.mm_encoder_attn_backend
) )
io_processor_plugin: str | None = None io_processor_plugin: str | None = None

View File

@ -626,14 +626,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "FLASH_ATTN_MLA": use FlashAttention for MLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA
# - "FLASHINFER_MLA": use FlashInfer for MLA # - "FLASHINFER_MLA": use FlashInfer for MLA
# - "CUTLASS_MLA": use CUTLASS for MLA # - "CUTLASS_MLA": use CUTLASS for MLA
# All possible options loaded dynamically from _Backend enum # All possible options loaded dynamically from AttentionBackendEnum
"VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND": env_with_choices(
"VLLM_ATTENTION_BACKEND", "VLLM_ATTENTION_BACKEND",
None, None,
lambda: list( lambda: list(
__import__( __import__(
"vllm.attention.backends.registry", fromlist=["_Backend"] "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"]
)._Backend.__members__.keys() ).AttentionBackendEnum.__members__.keys()
), ),
), ),
# If set, vllm will use flashinfer sampler # If set, vllm will use flashinfer sampler

View File

@ -9,7 +9,7 @@ import torch.nn.functional as F
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
@ -256,7 +256,7 @@ class DotsVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -303,17 +303,17 @@ class DotsVisionAttention(nn.Module):
) )
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}" f"Unsupported vision attention backend: {self.attn_backend}"
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def forward( def forward(
@ -361,7 +361,7 @@ class DotsVisionAttention(nn.Module):
self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1]) s = int(cu_seqlens[i - 1])
@ -373,7 +373,7 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3) out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i) outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@ -514,7 +514,7 @@ class DotsVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
@ -567,7 +567,7 @@ class DotsVisionTransformer(nn.Module):
require_post_norm: bool | None = None, require_post_norm: bool | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
@ -582,10 +582,11 @@ class DotsVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.out_hidden_size = config.hidden_size self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers # Keep blocks for compatibility with other vision towers
num_layers = ( num_layers = (
@ -666,11 +667,11 @@ class DotsVisionTransformer(nn.Module):
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -36,7 +36,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
@ -164,7 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
@ -211,17 +211,17 @@ class Ernie4_5_VisionAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now." f"Ernie45-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@ -291,7 +291,7 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
@ -310,7 +310,7 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@ -370,7 +370,7 @@ class Ernie4_5_VisionBlock(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None, norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -463,7 +463,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
@ -515,10 +515,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@ -565,11 +566,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import (
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
@ -252,7 +252,7 @@ class Glm4vVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
@ -306,18 +306,18 @@ class Glm4vVisionAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now." f"GLM-4V does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@ -377,7 +377,7 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
@ -396,7 +396,7 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@ -425,7 +425,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
@ -703,7 +703,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -772,10 +772,11 @@ class Glm4vVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@ -824,8 +825,8 @@ class Glm4vVisionTransformer(nn.Module):
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
) )
@ -360,7 +360,7 @@ class KeyeSiglipAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -414,17 +414,17 @@ class KeyeSiglipAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now." f"Keye-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def forward( def forward(
@ -489,7 +489,7 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale, softmax_scale=self.scale,
) )
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@ -536,7 +536,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@ -590,7 +590,7 @@ class KeyeSiglipEncoder(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -685,7 +685,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -768,7 +768,7 @@ class KeyeSiglipVisionModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()

View File

@ -10,7 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
@ -106,7 +106,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -135,7 +135,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
model_type = config.model_type model_type = config.model_type
if model_type == "siglip2_navit": if model_type == "siglip2_navit":

View File

@ -31,7 +31,7 @@ from transformers.modeling_outputs import (
) )
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
@ -580,8 +580,8 @@ class SiglipAttention(nn.Module):
projection_size: int, projection_size: int,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -621,8 +621,8 @@ class SiglipAttention(nn.Module):
) )
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@ -680,10 +680,10 @@ class SiglipAttention(nn.Module):
cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen,
batch_size, batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa, self.use_upstream_fa,
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = [] outputs = []
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1] start_idx = cu_seqlens[i - 1]
@ -702,7 +702,7 @@ class SiglipAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None: if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.") raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
*, *,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
): ):
super().__init__() super().__init__()
@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module):
) )
self.use_upstream_fa = False self.use_upstream_fa = False
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()): } and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True self.use_upstream_fa = True
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now." f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module):
max_seqlen = None max_seqlen = None
seqlens = None seqlens = None
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module):
config, config,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()

View File

@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig, Qwen2_5_VLVisionConfig,
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module):
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
self.use_upstream_fa = True self.use_upstream_fa = True
if current_platform.is_xpu(): if current_platform.is_xpu():
self.use_upstream_fa = False self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module):
cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen,
batch_size, batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA, self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa, self.use_upstream_fa,
) )
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module):
v, v,
cu_seqlens, cu_seqlens,
) )
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now." f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability, check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now." f"Qwen2-VL does not support {self.attn_backend} backend now."
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" context_layer, "b s h d -> s b (h d)"
).contiguous() ).contiguous()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@ -789,9 +790,12 @@ class Qwen2VisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]: ) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None max_seqlen, seqlens = None, None
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
) )
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( if (
torch.get_default_dtype() self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
) )
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = vision_config.hidden_size self.hidden_size = vision_config.hidden_size
@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module):
) )
use_upstream_fa = False use_upstream_fa = False
if ( if (
self.attn_backend != _Backend.FLASH_ATTN self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and self.attn_backend != _Backend.ROCM_AITER_FA and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype()) and check_upstream_fa_availability(torch.get_default_dtype())
): ):
self.attn_backend = _Backend.FLASH_ATTN self.attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True use_upstream_fa = True
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.XFORMERS, AttentionBackendEnum.XFORMERS,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
raise RuntimeError( raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now." f"Qwen3-VL does not support {self.attn_backend} backend now."
@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module):
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device)
if ( if (
self.attn_backend == _Backend.FLASH_ATTN self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS: elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens return max_seqlen, seqlens

View File

@ -12,7 +12,7 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
@ -208,7 +208,7 @@ class Siglip2Attention(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -264,14 +264,14 @@ class Siglip2Attention(nn.Module):
) )
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
}: }:
self.attn_backend = _Backend.TORCH_SDPA self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
_Backend.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
def forward( def forward(
@ -308,7 +308,7 @@ class Siglip2Attention(nn.Module):
attn_output = self.flash_attn_varlen_func( attn_output = self.flash_attn_varlen_func(
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
).reshape(seq_length, -1) ).reshape(seq_length, -1)
elif self.attn_backend == _Backend.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1 batch_size = cu_seqlens.shape[0] - 1
outputs = [] outputs = []
@ -376,7 +376,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@ -440,7 +440,7 @@ class Siglip2Encoder(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -626,7 +626,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
@ -667,7 +667,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
): ):
super().__init__() super().__init__()

View File

@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
@ -83,8 +83,8 @@ def get_vit_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
*, *,
attn_backend_override: _Backend | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> _Backend: ) -> AttentionBackendEnum:
""" """
Get the available attention backend for Vision Transformer. Get the available attention backend for Vision Transformer.
""" """
@ -94,7 +94,7 @@ def get_vit_attn_backend(
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from vllm.attention.selector import get_env_variable_attn_backend from vllm.attention.selector import get_env_variable_attn_backend
selected_backend: _Backend | None = get_env_variable_attn_backend() selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
if selected_backend is not None: if selected_backend is not None:
return selected_backend return selected_backend

View File

@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None AttentionBackendEnum = None
VllmConfig = None VllmConfig = None
@ -127,7 +127,7 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
@ -137,9 +137,9 @@ class CpuPlatform(Platform):
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
) -> str: ) -> str:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != _Backend.TORCH_SDPA: if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
@ -148,7 +148,7 @@ class CpuPlatform(Platform):
logger.info("Using Torch SDPA backend.") logger.info("Using Torch SDPA backend.")
if not use_v1: if not use_v1:
raise ValueError("CPU backend only supports V1.") raise ValueError("CPU backend only supports V1.")
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" return AttentionBackendEnum.TORCH_SDPA.get_path()
@classmethod @classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:

View File

@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else: else:
_Backend = None AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
logger = init_logger(__name__) logger = init_logger(__name__)
@ -39,6 +42,49 @@ pynvml = import_pynvml()
torch.backends.cuda.enable_cudnn_sdp(False) torch.backends.cuda.enable_cudnn_sdp(False)
@cache
def _get_backend_priorities(
use_mla: bool,
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla:
if device_capability.major == 10:
return [
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
return [
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
if device_capability.major == 10:
return [
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
]
else:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
]
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn) @wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
@ -216,217 +262,171 @@ class CudaPlatformBase(Platform):
return torch.cuda.max_memory_allocated(device) return torch.cuda.max_memory_allocated(device)
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": def get_vit_attn_backend(
from vllm.attention.backends.registry import _Backend cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
# For Blackwell GPUs, force TORCH_SDPA for now. # For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
if cls.has_device_capability(100): if cls.has_device_capability(100):
return _Backend.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
if dtype not in (torch.float16, torch.bfloat16): if dtype not in (torch.float16, torch.bfloat16):
return _Backend.XFORMERS return AttentionBackendEnum.XFORMERS
if cls.has_device_capability(80): if cls.has_device_capability(80):
FLASH_ATTN_V1 = ( backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if backend_class.supports_head_size(
) head_size
from vllm.attention.selector import is_attn_backend_supported ) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
is_default_fa_supported = is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
)
if is_default_fa_supported:
return _Backend.FLASH_ATTN
else: else:
# Fallback to XFORMERS return AttentionBackendEnum.XFORMERS
return _Backend.XFORMERS
else: else:
# Fallback for Volta/Turing GPUs or FA not supported # Fallback for Volta/Turing GPUs or FA not supported
return _Backend.XFORMERS return AttentionBackendEnum.XFORMERS
@classmethod @classmethod
def get_attn_backend_cls( def get_valid_backends(
cls, cls,
selected_backend,
head_size, head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size, block_size,
use_v1,
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
device_capability,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
]:
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
except ImportError:
invalid_reasons_i = ["ImportError"]
if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i
else:
valid_backends_priorities.append((backend, priority))
return valid_backends_priorities, invalid_reasons
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_v1: bool,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
) -> str: ) -> str:
from vllm.attention.backends.registry import _Backend if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
if use_mla: device_capability = cls.get_device_capability()
# explicitly reject non-MLA backends when MLA is enabled to avoid assert device_capability is not None
# silently selecting an incompatible backend (e.g., FLASHINFER).
if selected_backend in { # First try checking just the selected backend, if there is one.
_Backend.FLASHINFER, if selected_backend is not None:
_Backend.FLASH_ATTN, try:
_Backend.TRITON_ATTN, backend_class = selected_backend.get_class()
_Backend.TREE_ATTN, invalid_reasons = backend_class.validate_configuration(
_Backend.XFORMERS, head_size,
}: dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
device_capability,
)
except ImportError:
invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError( raise ValueError(
f"Attention backend {selected_backend} incompatible with MLA. " f"Selected backend {selected_backend} is not valid for "
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, " f"this configuration. Reason: {invalid_reasons}"
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
"VLLM_MLA_DISABLE=1 to disable MLA for this model."
) )
else:
logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
from vllm.attention.ops.flashmla import is_flashmla_dense_supported # No selected backend or the selected backend is invalid,
from vllm.attention.utils.fa_utils import flash_attn_supports_mla # so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
if use_sparse: head_size,
logger.info_once("Using Sparse MLA backend.") dtype,
return ( kv_cache_dtype,
"vllm.v1.attention.backends.mla.flashmla_sparse." None,
"FlashMLASparseBackend" use_mla,
) has_sink,
use_sparse,
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( device_capability,
selected_backend is None
and cls.is_device_capability(100)
and block_size % 128 == 0
)
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None
and cls.is_device_capability(100)
and (block_size == 32 or block_size % 64 == 0)
)
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_dense_supported()[0]
)
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
selected_backend is None and flash_attn_supports_mla()
)
use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None
)
if use_cutlassmla:
logger.info_once("Using Cutlass MLA backend.", scope="local")
return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
if use_flashinfermla:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout("HND")
logger.info_once("Using FlashInfer MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
if use_flashmla:
if block_size % 64 != 0:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size,
)
else:
logger.info_once("Using FlashMLA backend.")
return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
if use_flashattn:
logger.info_once("Using FlashAttention MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
)
if use_triton:
logger.info_once("Using Triton MLA backend.")
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = (
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
) )
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 reasons_str = (
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 "{"
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 + ", ".join(
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = (
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
if len(valid_backends_priorities) == 0:
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( # We have found some valid backends. Select the one with the
"fp8" # highest priority.
logger.info(
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
)
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info(
"Using %s backend.",
selected_backend.name,
) )
if selected_backend == _Backend.FLASHINFER: return selected_backend.get_path()
logger.info_once("Using FlashInfer backend.")
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend.")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS:
logger.info_once("Using XFormers backend.")
return XFORMERS_V1
from vllm.attention.selector import is_attn_backend_supported
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype
):
from vllm.v1.attention.backends.utils import set_kv_cache_layout
logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs."
)
set_kv_cache_layout("HND")
return FLASHINFER_V1
if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
"it is recommended to install FlashInfer for better "
"performance."
)
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
):
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
assert not is_default_backend_supported
use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype
logger.info_once(
"Using FlexAttention backend for %s.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
)
return FLEX_ATTENTION_V1
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:

View File

@ -17,8 +17,9 @@ from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple):
major: int major: int
minor: int minor: int
def __lt__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) < (other.major, other.minor)
def __le__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) <= (other.major, other.minor)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) == (other.major, other.minor)
def __ge__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) >= (other.major, other.minor)
def __gt__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) > (other.major, other.minor)
def as_version_str(self) -> str: def as_version_str(self) -> str:
return f"{self.major}.{self.minor}" return f"{self.major}.{self.minor}"
@ -173,19 +199,21 @@ class Platform:
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": def get_vit_attn_backend(
# Import _Backend here to avoid circular import. cls, head_size: int, dtype: torch.dtype
from vllm.attention.backends.registry import _Backend ) -> "AttentionBackendEnum":
# Import AttentionBackendEnum here to avoid circular import.
from vllm.attention.backends.registry import AttentionBackendEnum
return _Backend.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: "CacheDType | None",
block_size: int, block_size: int,
use_v1: bool, use_v1: bool,
use_mla: bool, use_mla: bool,

View File

@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
_Backend = None AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
@ -204,21 +204,23 @@ class RocmPlatform(Platform):
] ]
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled(): if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models. # Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class. # TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None: if on_gfx9() and find_spec("flash_attn") is not None:
return _Backend.FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN
return _Backend.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
@ -234,7 +236,7 @@ class RocmPlatform(Platform):
use_sparse, use_sparse,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.") raise NotImplementedError("Sparse Attention is not supported on ROCm.")
@ -248,55 +250,52 @@ class RocmPlatform(Platform):
if use_mla: if use_mla:
if selected_backend is None: if selected_backend is None:
selected_backend = ( selected_backend = (
_Backend.ROCM_AITER_MLA AttentionBackendEnum.ROCM_AITER_MLA
if rocm_aiter_ops.is_mla_enabled() or block_size == 1 if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA else AttentionBackendEnum.TRITON_MLA
) )
if selected_backend == _Backend.TRITON_MLA: if selected_backend == AttentionBackendEnum.TRITON_MLA:
if block_size != 1: if block_size != 1:
logger.info_once("Using Triton MLA backend.") logger.info_once("Using Triton MLA backend.")
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" return AttentionBackendEnum.TRITON_MLA.get_path()
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}." f"does not support block size {block_size}."
) )
if selected_backend == _Backend.ROCM_AITER_MLA: if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
logger.info("Using AITER MLA backend.") logger.info("Using AITER MLA backend.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend." f"is not MLA type while requested for MLA backend."
) )
if selected_backend == _Backend.FLEX_ATTENTION: if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.") logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if ( if (
rocm_aiter_ops.is_mha_enabled() rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA: ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.") logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" return AttentionBackendEnum.ROCM_AITER_FA.get_path()
if ( if (
rocm_aiter_ops.is_triton_unified_attn_enabled() rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.") logger.info("Using Aiter Unified Attention backend.")
return ( return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
"vllm.v1.attention.backends."
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
)
if ( if (
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
or selected_backend == _Backend.ROCM_ATTN or selected_backend == AttentionBackendEnum.ROCM_ATTN
): ):
# rocm specific backend, with aiter and/or # rocm specific backend, with aiter and/or
# triton prefix-prefill # triton prefix-prefill
logger.info("Using Rocm Attention backend.") logger.info("Using Rocm Attention backend.")
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" return AttentionBackendEnum.ROCM_ATTN.get_path()
# default case, using triton unified attention # default case, using triton unified attention
logger.info("Using Triton Attention backend.") logger.info("Using Triton Attention backend.")
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" return AttentionBackendEnum.TRITON_ATTN.get_path()
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:

View File

@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import BlockSize from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
else: else:
BlockSize = None BlockSize = None
ModelConfig = None
VllmConfig = None VllmConfig = None
PoolingParams = None PoolingParams = None
_Backend = None AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
@ -54,7 +53,7 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
@ -64,17 +63,17 @@ class TpuPlatform(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
) -> str: ) -> str:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.") raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != _Backend.PALLAS: if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
if not use_v1: if not use_v1:
raise ValueError("TPU backend only supports V1.") raise ValueError("TPU backend only supports V1.")
logger.info("Using Pallas V1 backend.") logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" return AttentionBackendEnum.PALLAS.get_path()
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:

View File

@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
else: else:
ModelConfig = None
VllmConfig = None VllmConfig = None
_Backend = None AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
@ -44,7 +43,7 @@ class XPUPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "_Backend", selected_backend: "AttentionBackendEnum",
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
@ -62,18 +61,19 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels." "only NHD layout is supported by XPU attention kernels."
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.") raise NotImplementedError("Sparse Attention is not supported on XPU.")
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 use_v1 = envs.VLLM_USE_V1
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 if not use_v1:
if selected_backend == _Backend.TRITON_ATTN: raise ValueError("XPU backend only supports V1.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.") logger.info_once("Using Triton backend.")
return TRITON_ATTN return AttentionBackendEnum.TRITON_ATTN.get_path()
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.") logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN.get_path()
elif selected_backend: elif selected_backend:
raise ValueError( raise ValueError(
f"Invalid attention backend for {cls.device_name}, " f"Invalid attention backend for {cls.device_name}, "
@ -81,7 +81,7 @@ class XPUPlatform(Platform):
) )
logger.info("Using Flash Attention backend.") logger.info("Using Flash Attention backend.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod @classmethod
def set_device(cls, device: torch.device) -> None: def set_device(cls, device: torch.device) -> None:
@ -113,10 +113,10 @@ class XPUPlatform(Platform):
return device_props.total_memory return device_props.total_memory
@classmethod @classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: def get_vit_attn_backend(
from vllm.attention.backends.registry import _Backend cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
return _Backend.FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import numpy as np import numpy as np
import torch import torch
@ -40,23 +40,16 @@ logger = init_logger(__name__)
class TorchSDPABackend(AttentionBackend): class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod @classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
attn_impl = _get_paged_attn_impl() attn_impl = _get_paged_attn_impl()
is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) return attn_impl.get_supported_head_sizes()
if not is_valid:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -759,9 +752,8 @@ def _make_sliding_window_bias(
class _PagedAttention: class _PagedAttention:
@staticmethod @staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]: def get_supported_head_sizes() -> list[int]:
SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] return [32, 64, 80, 96, 112, 128, 192, 256]
return head_size in SUPPORT_HS, SUPPORT_HS
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
@ -861,8 +853,8 @@ class _PagedAttention:
class _IPEXPagedAttention(_PagedAttention): class _IPEXPagedAttention(_PagedAttention):
@staticmethod @staticmethod
def validate_head_size(head_size: int) -> tuple[bool, list[int]]: def get_supported_head_sizes() -> list[int]:
return True, [] return []
@staticmethod @staticmethod
def split_kv_cache( def split_kv_cache(

View File

@ -3,6 +3,7 @@
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import numpy as np import numpy as np
import torch import torch
@ -32,11 +33,13 @@ if is_flash_attn_varlen_func_available():
reshape_and_cache_flash, reshape_and_cache_flash,
) )
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
@ -52,34 +55,12 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod # NOTE(tdoublep): while in principle, FA supports
def get_supported_dtypes(cls) -> list[torch.dtype]: # MultipleOf(16), these are the block sizes that do not
return [torch.float16, torch.bfloat16] # suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -125,6 +106,38 @@ class FlashAttentionBackend(AttentionBackend):
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None:
return True
if kv_cache_dtype.startswith("fp8"):
return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto"]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability >= DeviceCapability(8, 0)
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if has_sink and device_capability < DeviceCapability(9, 0):
return "sink not supported on compute capability < 9.0"
return None
@dataclass @dataclass
class FlashAttentionMetadata: class FlashAttentionMetadata:
@ -481,8 +494,6 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
FlashAttentionBackend.validate_head_size(head_size)
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes # Cache the batch invariant result for use in forward passes

View File

@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
MultipleOf, MultipleOf,
) )
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
can_use_trtllm_attention, can_use_trtllm_attention,
@ -45,6 +47,7 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
KVCacheLayoutType,
get_kv_cache_layout, get_kv_cache_layout,
get_per_layer_parameters, get_per_layer_parameters,
infer_global_hyperparameters, infer_global_hyperparameters,
@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant(
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod # Note: Not sure for all platforms,
def get_supported_dtypes(cls) -> list[torch.dtype]: # but on Blackwell, only support a page size of
return [torch.float16, torch.bfloat16] # 16, 32, 64
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
@classmethod supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
def get_supported_head_sizes(cls) -> list[int]: "auto",
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 "fp8",
return [64, 128, 256] "fp8_e4m3",
"fp8_e5m2",
@staticmethod ]
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
return [16, 32, 64]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -231,6 +217,26 @@ class FlashInferBackend(AttentionBackend):
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
12, 1
)
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
from vllm.platforms import current_platform
capability = current_platform.get_device_capability()
if capability is not None and capability.major == 10:
return "HND"
return None
@dataclass @dataclass
class FlashInferMetadata: class FlashInferMetadata:
@ -328,7 +334,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size self.head_dim = self.kv_cache_spec.head_size
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size self.page_size = self.kv_cache_spec.block_size
self.cache_dtype = self.cache_config.cache_dtype self.cache_dtype = self.cache_config.cache_dtype

View File

@ -4,6 +4,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import torch import torch
import torch._dynamo.decorators import torch._dynamo.decorators
@ -24,6 +25,7 @@ from vllm.attention.backends.abstract import (
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
class FlexAttentionBackend(AttentionBackend): class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
@classmethod torch.float16,
def get_supported_dtypes(cls) -> list[torch.dtype]: torch.bfloat16,
return [torch.float16, torch.bfloat16, torch.float32] torch.float32,
]
@classmethod supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
def validate_head_size(cls, head_size: int) -> None:
return # FlexAttention supports any head size
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -106,6 +106,10 @@ class FlexAttentionBackend(AttentionBackend):
def use_cascade_attention(*args, **kwargs) -> bool: def use_cascade_attention(*args, **kwargs) -> bool:
return False return False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
# @torch.compile(fullgraph=True, mode="reduce-overhead") # @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping( def physical_to_logical_mapping(
@ -720,7 +724,6 @@ class FlexAttentionImpl(AttentionImpl):
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("FlexAttention does not support kv sharing yet.") raise NotImplementedError("FlexAttention does not support kv sharing yet.")
FlexAttentionBackend.validate_head_size(head_size)
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet" "FlexAttention does not support quantized kv-cache. Yet"

View File

@ -308,25 +308,13 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [576] return [576]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def is_mla(cls) -> bool:
supported_head_sizes = cls.get_supported_head_sizes() return True
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@dataclass @dataclass
@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]):
) = None ) = None
def __post_init__(self): def __post_init__(self):
if self.head_dim is not None: if self.head_dim is not None and not MLACommonBackend.supports_head_size(
MLACommonBackend.validate_head_size(self.head_dim) self.head_dim
):
raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")
M = TypeVar("M", bound=MLACommonMetadata) M = TypeVar("M", bound=MLACommonMetadata)

View File

@ -13,7 +13,9 @@ from vllm.attention.backends.abstract import (
MultipleOf, MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend): class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "CUTLASS_MLA" return "CUTLASS_MLA"
@ -45,9 +55,9 @@ class CutlassMLABackend(MLACommonBackend):
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
return CutlassMLAMetadataBuilder return CutlassMLAMetadataBuilder
@staticmethod @classmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return [128] return capability.major == 10
class SM100Workspace: class SM100Workspace:

View File

@ -10,6 +10,7 @@ from vllm import envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.attention.utils.fa_utils import ( from vllm.attention.utils.fa_utils import (
@ -17,10 +18,12 @@ from vllm.attention.utils.fa_utils import (
get_flash_attn_version, get_flash_attn_version,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
@ -37,6 +40,10 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend): class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_MLA" return "FLASH_ATTN_MLA"
@ -49,6 +56,26 @@ class FlashAttnMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashAttnMLAImpl"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]:
return FlashAttnMLAImpl return FlashAttnMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major == 9
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if not flash_attn_supports_mla():
return "FlashAttention MLA not supported on this device"
return None
@dataclass @dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):

View File

@ -6,8 +6,14 @@ from typing import ClassVar
import torch import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
@ -15,7 +21,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport, QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
logger = init_logger(__name__) logger = init_logger(__name__)
@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend): class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHINFER_MLA" return "FLASHINFER_MLA"
@ -41,8 +55,12 @@ class FlashInferMLABackend(MLACommonBackend):
return FlashInferMLAMetadataBuilder return FlashInferMLAMetadataBuilder
@classmethod @classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return [32, 64] return capability.major == 10
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
return "HND"
g_fi_workspace = torch.zeros( g_fi_workspace = torch.zeros(

View File

@ -13,10 +13,12 @@ from vllm.attention.ops.flashmla import (
is_flashmla_dense_supported, is_flashmla_dense_supported,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
@ -36,6 +38,14 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend): class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHMLA" return "FLASHMLA"
@ -48,9 +58,30 @@ class FlashMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashMLAImpl"]: def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl return FlashMLAImpl
@staticmethod @classmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return [64] return capability.major in [9, 10]
@classmethod
def supports_combination(
cls,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
if use_sparse:
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
return is_flashmla_sparse_supported()[1]
else:
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
return is_flashmla_dense_supported()[1]
@dataclass @dataclass

View File

@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
MultipleOf,
) )
from vllm.attention.backends.utils import get_mla_dims from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import ( from vllm.attention.ops.flashmla import (
@ -18,8 +19,10 @@ from vllm.attention.ops.flashmla import (
get_mla_metadata, get_mla_metadata,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
@ -51,6 +54,9 @@ structured as:
class FlashMLASparseBackend(AttentionBackend): class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -64,6 +70,22 @@ class FlashMLASparseBackend(AttentionBackend):
def get_impl_cls() -> type["FlashMLASparseImpl"]: def get_impl_cls() -> type["FlashMLASparseImpl"]:
return FlashMLASparseImpl return FlashMLASparseImpl
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@classmethod
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return capability.major in [9, 10]
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
@ -79,14 +101,6 @@ class FlashMLASparseBackend(AttentionBackend):
else: else:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@dataclass @dataclass
class FlashMLASparseMetadata: class FlashMLASparseMetadata:

View File

@ -23,6 +23,8 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend): class DeepseekV32IndexerBackend(AttentionBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 128] return [32, 64, 128]
@ -46,10 +48,6 @@ class DeepseekV32IndexerBackend(AttentionBackend):
def get_kv_cache_stride_order() -> tuple[int, ...]: def get_kv_cache_stride_order() -> tuple[int, ...]:
return (0, 1, 2) return (0, 1, 2)
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [64]
@dataclass @dataclass
class DeepseekV32IndexerPrefillChunkMetadata: class DeepseekV32IndexerPrefillChunkMetadata:

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch import torch
@ -12,11 +13,13 @@ from vllm.attention.backends.abstract import (
) )
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
@ -28,6 +31,9 @@ logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend): class TritonMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA" return "TRITON_MLA"
@ -36,6 +42,10 @@ class TritonMLABackend(MLACommonBackend):
def get_impl_cls() -> type["TritonMLAImpl"]: def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl return TritonMLAImpl
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True can_return_lse_for_decode: bool = True

View File

@ -3,6 +3,7 @@
"""Attention layer with AiterFlashAttention.""" """Attention layer with AiterFlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import torch import torch
@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend): class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [64, 128, 256] return [64, 128, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
AiterFlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "

View File

@ -152,10 +152,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend): class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
@ -163,12 +160,11 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes() if not cls.supports_head_size(head_size):
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend") attn_type = cls.__name__.removesuffix("Backend")
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. " f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes." "FlexAttention backend which supports all head sizes."
) )

View File

@ -4,7 +4,7 @@
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import torch import torch
@ -30,31 +30,13 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend): class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TREE_ATTN" return "TREE_ATTN"
@ -331,8 +313,6 @@ class TreeAttentionImpl(AttentionImpl):
else: else:
self.sliding_window = (sliding_window - 1, 0) self.sliding_window = (sliding_window - 1, 0)
TreeAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "

View File

@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
) )
from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
class TritonAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
@classmethod torch.float16,
def get_supported_dtypes(cls) -> list[torch.dtype]: torch.bfloat16,
return [torch.float16, torch.bfloat16, torch.float32] torch.float32,
]
@staticmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_kernel_block_size() -> list[int | MultipleOf]: supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
return [MultipleOf(16)] "auto",
"fp8",
@classmethod "fp8_e4m3",
def validate_head_size(cls, head_size: int) -> None: "fp8_e5m2",
# Triton Attention supports any head size above 32 ]
if head_size < 32:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention."
f"Head sizes need to be larger or equal 32 for this backend. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder return TritonAttentionMetadataBuilder
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_sink(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
class TritonAttentionImpl(AttentionImpl): class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
TritonAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention.""" """Attention layer with XFormersAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import torch import torch
@ -41,10 +41,8 @@ logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend): class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@classmethod supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
@ -80,22 +78,6 @@ class XFormersAttentionBackend(AttentionBackend):
256, 256,
] ]
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "XFORMERS" return "XFORMERS"
@ -305,8 +287,6 @@ class XFormersAttentionImpl(AttentionImpl):
logits_soft_cap = 0 logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap self.logits_soft_cap = logits_soft_cap
XFormersAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention and " "Encoder self-attention and "

View File

@ -150,11 +150,15 @@ class EagleProposer:
) )
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
from vllm.attention.backends.registry import AttentionBackendEnum
self.allowed_attn_types: tuple | None = None self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm(): if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend # ROCM_AITER_FA is an optional backend
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): if find_spec(
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
):
from vllm.v1.attention.backends.rocm_aiter_fa import ( from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata, AiterFlashAttentionMetadata,
) )

View File

@ -4371,7 +4371,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
""" """
for backend in backends: for backend in backends:
is_supported = False is_supported = False
for supported_size in backend.get_supported_kernel_block_size(): for supported_size in backend.supported_kernel_block_sizes:
if isinstance(supported_size, int): if isinstance(supported_size, int):
if block_size == supported_size: if block_size == supported_size:
is_supported = True is_supported = True
@ -4402,7 +4402,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
all_int_supported_sizes = set( all_int_supported_sizes = set(
supported_size supported_size
for backend in backends for backend in backends
for supported_size in backend.get_supported_kernel_block_size() for supported_size in backend.supported_kernel_block_sizes
if isinstance(supported_size, int) if isinstance(supported_size, int)
) )