mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-19 08:07:07 +08:00
[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
2e78150d24
commit
b30dfa03c5
@ -890,11 +890,16 @@ steps:
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/v1/attention/backends/mla/cutlass_mla.py
|
||||
- vllm/v1/attention/backends/mla/flashinfer_mla.py
|
||||
- vllm/platforms/cuda.py
|
||||
- vllm/attention/selector.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
# Attention
|
||||
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
|
||||
- pytest -v -s tests/kernels/attention/test_attention_selector.py
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
||||
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
||||
|
||||
@ -10,7 +10,7 @@ from tests.utils import flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
@ -104,7 +104,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
|
||||
# TODO(luka) use get_kv_cache_stride_order
|
||||
# Create dummy KV cache for the selected backend
|
||||
if backend == _Backend.ROCM_ATTN:
|
||||
if backend == AttentionBackendEnum.ROCM_ATTN:
|
||||
# k/v as 1st dimention
|
||||
# HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
@ -116,7 +116,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
||||
# k/v as 1st dimention
|
||||
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
@ -128,7 +128,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.TRITON_ATTN:
|
||||
elif backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
# k/v as 2nd dimention
|
||||
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
kv_cache = torch.zeros(
|
||||
@ -140,7 +140,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
elif backend == AttentionBackendEnum.FLASHINFER:
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
2,
|
||||
@ -244,8 +244,8 @@ MODELS_FP8: list[tuple[str, type]] = []
|
||||
MODELS_FP4: list[tuple[str, type]] = []
|
||||
HEADS: list[tuple[int, int]] = []
|
||||
SPLIT_ATTENTION: list[bool] = []
|
||||
BACKENDS_FP8: list[_Backend] = []
|
||||
BACKENDS_FP4: list[_Backend] = []
|
||||
BACKENDS_FP8: list[AttentionBackendEnum] = []
|
||||
BACKENDS_FP4: list[AttentionBackendEnum] = []
|
||||
|
||||
if current_platform.is_cuda():
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
@ -261,8 +261,8 @@ if current_platform.is_cuda():
|
||||
TestAttentionNvfp4QuantPatternModel,
|
||||
)
|
||||
]
|
||||
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
|
||||
BACKENDS_FP4 = [_Backend.FLASHINFER]
|
||||
BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
|
||||
BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
@ -270,9 +270,9 @@ elif current_platform.is_rocm():
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
BACKENDS = [
|
||||
_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
_Backend.ROCM_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
]
|
||||
|
||||
|
||||
@ -302,11 +302,11 @@ def test_attention_quant_pattern(
|
||||
custom_ops: str,
|
||||
model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
dist_init,
|
||||
):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
if backend == _Backend.FLASHINFER and (
|
||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
@ -314,6 +314,7 @@ def test_attention_quant_pattern(
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(42)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
@ -402,7 +403,7 @@ def test_attention_quant_pattern(
|
||||
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
if backend == _Backend.FLASHINFER:
|
||||
if backend == AttentionBackendEnum.FLASHINFER:
|
||||
# With the Flashinfer backend after the 1st round of the forward
|
||||
# pass, output quant scale should be loaded into the attn layer's
|
||||
# _o_scale_float, the 2nd round should reuse the loaded
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any, NamedTuple
|
||||
import pytest
|
||||
import regex as re
|
||||
|
||||
from tests.v1.attention.utils import _Backend
|
||||
from tests.v1.attention.utils import AttentionBackendEnum
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
|
||||
class ModelBackendTestCase(NamedTuple):
|
||||
model_name: str
|
||||
model_kwargs: dict[str, Any]
|
||||
backend: _Backend
|
||||
backend: AttentionBackendEnum
|
||||
attention_fusions: int
|
||||
allreduce_fusions: int | None = None
|
||||
|
||||
@ -39,14 +39,14 @@ if current_platform.is_cuda():
|
||||
# Use smaller model for L40s in CI
|
||||
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=_Backend.FLASHINFER,
|
||||
backend=AttentionBackendEnum.FLASHINFER,
|
||||
attention_fusions=48,
|
||||
allreduce_fusions=96,
|
||||
),
|
||||
@ -56,7 +56,7 @@ if current_platform.is_cuda():
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=_Backend.FLASHINFER,
|
||||
backend=AttentionBackendEnum.FLASHINFER,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
@ -67,7 +67,7 @@ if current_platform.is_cuda():
|
||||
ModelBackendTestCase(
|
||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
@ -85,19 +85,19 @@ elif current_platform.is_rocm():
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.ROCM_ATTN,
|
||||
backend=AttentionBackendEnum.ROCM_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
]
|
||||
@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
custom_ops: str,
|
||||
@ -125,7 +125,7 @@ def test_attn_quant(
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if backend == _Backend.FLASHINFER and (
|
||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
|
||||
def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
model_name: str,
|
||||
model_kwargs: dict,
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
custom_ops: str,
|
||||
|
||||
@ -3,13 +3,13 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_str_conversion():
|
||||
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
|
||||
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
|
||||
assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
|
||||
def test_mm_encoder_attn_backend_invalid():
|
||||
@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
|
||||
def test_mm_encoder_attn_backend_hash_updates():
|
||||
base_hash = MultiModalConfig().compute_hash()
|
||||
overridden_hash = MultiModalConfig(
|
||||
mm_encoder_attn_backend=_Backend.FLASH_ATTN
|
||||
mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
|
||||
).compute_hash()
|
||||
assert base_hash != overridden_hash
|
||||
|
||||
@ -120,12 +120,13 @@ def test_env(
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if use_mla:
|
||||
# CUDA MLA backend logic:
|
||||
# - CUTLASS_MLA: only supported with block_size == 128
|
||||
# and Blackwell GPUs (SM 10.0), V1 only
|
||||
# and Blackwell GPUs (SM 10.x), V1 only
|
||||
# - FLASHINFER_MLA: only supported on Blackwell GPUs
|
||||
# (SM 10.0+), V1 only
|
||||
# (SM 10.x), V1 only
|
||||
# - FLASHMLA: only supported with block_size == 64
|
||||
# - FLASH_ATTN_MLA: V1 only
|
||||
# - TRITON_MLA: fallback for other cases
|
||||
@ -134,58 +135,72 @@ def test_env(
|
||||
if block_size != 128:
|
||||
# CUTLASS_MLA only supports block_size == 128
|
||||
pytest.skip("CUTLASS_MLA only supports block_size 128")
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "CUTLASS_MLA"
|
||||
assert backend.get_name() == expected
|
||||
if capability[0] != 10:
|
||||
pytest.skip("CUTLASS MLA is not supported on this platform")
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "CUTLASS_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER_MLA":
|
||||
if capability[0] != 10:
|
||||
pytest.skip(
|
||||
"FlashInfer MLA is not supported on this platform"
|
||||
)
|
||||
if block_size not in [32, 64]:
|
||||
# FlashInfer MLA only supports block_size 32 or 64
|
||||
pytest.skip(
|
||||
"FlashInfer MLA only supports block_size 32 or 64"
|
||||
)
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHMLA":
|
||||
if block_size != 64:
|
||||
# FlashMLA only supports block_size == 64
|
||||
pytest.skip("FlashMLA only supports block_size 64")
|
||||
else:
|
||||
from vllm.v1.attention.backends.mla.flashmla import (
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.flashmla import (
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
|
||||
is_supported, _ = is_flashmla_dense_supported()
|
||||
if not is_supported:
|
||||
pytest.skip("FlashMLA not supported on this platform")
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN_MLA":
|
||||
is_supported, _ = is_flashmla_dense_supported()
|
||||
if not is_supported:
|
||||
pytest.skip("FlashMLA not supported on this platform")
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
576,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN_MLA":
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
)
|
||||
|
||||
if not flash_attn_supports_mla():
|
||||
pytest.skip(
|
||||
"FlashAttention MLA not supported on this platform"
|
||||
)
|
||||
backend = get_attn_backend(
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASH_ATTN_MLA"
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
# TRITON_MLA or other fallback
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "TRITON_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER":
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
64, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
@ -11,7 +11,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.selector import _cached_get_attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str):
|
||||
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
elif device == "hip":
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", RocmPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
else:
|
||||
# Test CUDA with head_size=64 (divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str):
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA not available
|
||||
@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str):
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
assert attn.attn_backend == AttentionBackendEnum.XFORMERS
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA available
|
||||
@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str):
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
|
||||
def ref_attention(
|
||||
|
||||
@ -93,6 +93,17 @@ def can_initialize(
|
||||
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
|
||||
)
|
||||
|
||||
if model_arch == "DeepseekV32ForCausalLM":
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability and capability.major < 9:
|
||||
pytest.skip(
|
||||
f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) "
|
||||
f"for FLASHMLA_SPARSE backend. Current device has compute "
|
||||
f"capability {capability.major}.{capability.minor}"
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
|
||||
monkeypatch.context() as m,
|
||||
|
||||
@ -15,7 +15,7 @@ from tests.v1.attention.utils import (
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
@ -27,11 +27,11 @@ from vllm.v1.attention.backends.utils import (
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLASHINFER,
|
||||
_Backend.FLEX_ATTENTION,
|
||||
_Backend.TRITON_ATTN,
|
||||
_Backend.TREE_ATTN,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.TREE_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
|
||||
@ -39,7 +39,7 @@ BACKENDS_TO_TEST = [
|
||||
try:
|
||||
import flashinfer # noqa: F401
|
||||
except ImportError:
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER)
|
||||
|
||||
|
||||
def _convert_dtype_to_torch(dtype):
|
||||
@ -192,7 +192,7 @@ class MockAttentionLayer:
|
||||
|
||||
|
||||
def run_attention_backend(
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config,
|
||||
@ -211,13 +211,13 @@ def run_attention_backend(
|
||||
|
||||
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
|
||||
if backend == "FLEX_ATTENTION_SLOW":
|
||||
actual_backend = _Backend.FLEX_ATTENTION
|
||||
actual_backend = AttentionBackendEnum.FLEX_ATTENTION
|
||||
use_direct_block_mask = False
|
||||
|
||||
builder_cls, impl_cls = try_get_attention_backend(actual_backend)
|
||||
|
||||
# Mock flashinfer's get_per_layer_parameters if needed
|
||||
if actual_backend == _Backend.FLASHINFER:
|
||||
if actual_backend == AttentionBackendEnum.FLASHINFER:
|
||||
import unittest.mock
|
||||
|
||||
from vllm.v1.attention.backends.utils import PerLayerParameters
|
||||
@ -246,7 +246,7 @@ def run_attention_backend(
|
||||
else:
|
||||
# Build metadata
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
if actual_backend == _Backend.FLEX_ATTENTION:
|
||||
if actual_backend == AttentionBackendEnum.FLEX_ATTENTION:
|
||||
builder.direct_build = use_direct_block_mask
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
@ -289,7 +289,7 @@ def run_attention_backend(
|
||||
def _test_backend_correctness(
|
||||
batch_spec: BatchSpec,
|
||||
model: str,
|
||||
backend_to_test: list[_Backend | str],
|
||||
backend_to_test: list[AttentionBackendEnum | str],
|
||||
mask_mod,
|
||||
*,
|
||||
block_size: int = 16,
|
||||
@ -455,17 +455,20 @@ def _test_backend_correctness(
|
||||
# Select the appropriate KV cache format for each backend
|
||||
kv_cache_for_backend = kv_cache
|
||||
reset_kv_cache_layout = False
|
||||
if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
|
||||
if backend_name in (
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
):
|
||||
kv_cache_for_backend = kv_cache.transpose(0, 1)
|
||||
|
||||
if backend_name == _Backend.FLASHINFER:
|
||||
if backend_name == AttentionBackendEnum.FLASHINFER:
|
||||
# For FlashInfer default to HND layout and
|
||||
kv_cache_for_backend = (
|
||||
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
|
||||
)
|
||||
set_kv_cache_layout("HND")
|
||||
reset_kv_cache_layout = True
|
||||
elif backend_name == _Backend.TRITON_ATTN:
|
||||
elif backend_name == AttentionBackendEnum.TRITON_ATTN:
|
||||
kv_cache_for_backend = kv_cache_for_backend.contiguous()
|
||||
|
||||
try:
|
||||
@ -547,7 +550,9 @@ def test_causal_backend_correctness(
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
LARGE_BLOCK_BACKENDS = (
|
||||
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
|
||||
[AttentionBackendEnum.FLEX_ATTENTION]
|
||||
if is_torch_equal_or_newer("2.9.0.dev0")
|
||||
else []
|
||||
)
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
@ -573,9 +578,9 @@ def test_causal_backend_correctness(
|
||||
|
||||
|
||||
SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLEX_ATTENTION,
|
||||
_Backend.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
|
||||
@ -612,7 +617,9 @@ def test_sliding_window_backend_correctness(
|
||||
)
|
||||
|
||||
LARGE_BLOCK_BACKENDS = (
|
||||
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
|
||||
[AttentionBackendEnum.FLEX_ATTENTION]
|
||||
if is_torch_equal_or_newer("2.9.0.dev0")
|
||||
else []
|
||||
)
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
|
||||
@ -18,12 +18,11 @@ from tests.v1.attention.utils import (
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.registry import _Backend, backend_to_class_str
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||
@ -31,25 +30,25 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.CUTLASS_MLA,
|
||||
_Backend.FLASHMLA,
|
||||
_Backend.FLASH_ATTN_MLA,
|
||||
_Backend.FLASHINFER_MLA,
|
||||
_Backend.TRITON_MLA,
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
]
|
||||
|
||||
# Remove sm100 backends from the list if not using sm100
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)
|
||||
|
||||
# Remove FLASH_ATTN_MLA from the list if not supported
|
||||
if not flash_attn_supports_mla():
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA)
|
||||
|
||||
# Remove FLASHMLA from the list if not supported
|
||||
if not is_flashmla_dense_supported()[0]:
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
|
||||
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
|
||||
|
||||
SPEC_DECODE_BACKENDS = []
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
@ -62,9 +61,7 @@ for backend in BACKENDS_TO_TEST:
|
||||
|
||||
BACKEND_BLOCK_SIZES = {}
|
||||
for backend in BACKENDS_TO_TEST:
|
||||
backend_class_str = backend_to_class_str(backend)
|
||||
backend_class = resolve_obj_by_qualname(backend_class_str)
|
||||
supported_sizes = backend_class.get_supported_kernel_block_size()
|
||||
supported_sizes = backend.get_class().supported_kernel_block_sizes
|
||||
if supported_sizes:
|
||||
default_size = supported_sizes[0]
|
||||
block_size = (
|
||||
@ -291,7 +288,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
|
||||
|
||||
|
||||
def run_attention_backend(
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config,
|
||||
@ -813,7 +810,7 @@ def test_backend_correctness(
|
||||
# Create a summary for the single-line failure message
|
||||
backend_names = []
|
||||
for f in failures:
|
||||
if "[_Backend." in f:
|
||||
if "[AttentionBackendEnum." in f:
|
||||
backend_name = f.split("[")[1].split("]")[0]
|
||||
backend_names.append(backend_name)
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionImpl
|
||||
from vllm.attention.backends.registry import _Backend, backend_to_class_str
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
@ -20,7 +20,6 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.model import ModelDType
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
@ -120,15 +119,14 @@ def create_common_attn_metadata(
|
||||
|
||||
|
||||
def try_get_attention_backend(
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
|
||||
"""Try to get the attention backend class, skipping test if not found."""
|
||||
backend_class_str = backend_to_class_str(backend)
|
||||
try:
|
||||
backend_class = resolve_obj_by_qualname(backend_class_str)
|
||||
backend_class = backend.get_class()
|
||||
return backend_class.get_builder_cls(), backend_class.get_impl_cls()
|
||||
except ImportError as e:
|
||||
pytest.skip(f"{backend_class_str} not available: {e}")
|
||||
pytest.skip(f"{backend.name} not available: {e}")
|
||||
raise AssertionError("unreachable") from None
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from tests.v1.attention.utils import (
|
||||
create_standard_kv_cache_spec,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
@ -534,11 +534,17 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
if attn_backend == "FLASH_ATTN":
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
AttentionBackendEnum.FLASH_ATTN
|
||||
)
|
||||
elif attn_backend == "TRITON_ATTN":
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
AttentionBackendEnum.TRITON_ATTN
|
||||
)
|
||||
elif attn_backend == "TREE_ATTN":
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
AttentionBackendEnum.TREE_ATTN
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention backend: {attn_backend}")
|
||||
|
||||
@ -673,7 +679,9 @@ def test_propose_tree(spec_token_tree):
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
|
||||
# Get the tree attention metadata builder.
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
AttentionBackendEnum.TREE_ATTN
|
||||
)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
|
||||
@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
|
||||
create_standard_kv_cache_spec,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
@ -177,7 +177,9 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
# Setup attention metadata
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
AttentionBackendEnum.FLASH_ATTN
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
|
||||
@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
|
||||
create_vllm_config,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
@ -35,7 +35,7 @@ def forward_attention(
|
||||
block_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
seqlen_k: int,
|
||||
backend: _Backend,
|
||||
backend: AttentionBackendEnum,
|
||||
spec_token_tree: str | None = None,
|
||||
num_spec_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None:
|
||||
block_table=block_table,
|
||||
slot_mapping=tree_slot_mapping,
|
||||
seqlen_k=seqlen_k,
|
||||
backend=_Backend.TREE_ATTN,
|
||||
backend=AttentionBackendEnum.TREE_ATTN,
|
||||
spec_token_tree=spec_token_tree,
|
||||
num_spec_tokens=tree_size_q - 1,
|
||||
).view(batch_size, -1, num_heads, dim_per_head)
|
||||
@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None:
|
||||
block_table=block_table,
|
||||
slot_mapping=branch_slot_mapping,
|
||||
seqlen_k=sequence_position + q_len,
|
||||
backend=_Backend.FLASH_ATTN,
|
||||
backend=AttentionBackendEnum.FLASH_ATTN,
|
||||
).view(batch_size, -1, num_heads, dim_per_head)
|
||||
|
||||
# Compare the outputs.
|
||||
|
||||
@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size(
|
||||
supported_sizes: list[int | MultipleOf],
|
||||
):
|
||||
class _MockBackend:
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size():
|
||||
return supported_sizes
|
||||
supported_kernel_block_sizes = supported_sizes
|
||||
|
||||
return _MockBackend()
|
||||
|
||||
@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
# This test checks if GPUModelRunner initializes correctly when an attention
|
||||
# backend enforces a non-default KV cache stride order.
|
||||
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
|
||||
expected_kv_cache_shape = [
|
||||
2,
|
||||
NUM_BLOCKS,
|
||||
BLOCK_SIZE,
|
||||
n_heads,
|
||||
model_runner.model_config.get_head_size(),
|
||||
]
|
||||
head_size = model_runner.model_config.get_head_size()
|
||||
|
||||
# Get the expected shape from the backend's get_kv_cache_shape method
|
||||
# to ensure compatibility with different backends (triton vs flexattention)
|
||||
attn_backend = None
|
||||
for attn_group in model_runner._attn_group_iterator():
|
||||
attn_backend = attn_group.backend
|
||||
break
|
||||
|
||||
assert attn_backend is not None, "No attention backend found"
|
||||
expected_kv_cache_shape = list(
|
||||
attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size)
|
||||
)
|
||||
|
||||
# TODO mla test
|
||||
default_stride = tuple(range(5))
|
||||
# Permutation that gets you back to expected kv shape
|
||||
|
||||
@ -2,13 +2,18 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, Protocol, TypeVar
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
|
||||
|
||||
class AttentionType:
|
||||
"""
|
||||
@ -40,6 +45,9 @@ class AttentionBackend(ABC):
|
||||
# calling the custom op. When piecewise cudagraph is enabled, this
|
||||
# makes sure the output tensor is allocated inside the cudagraph.
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
@ -51,10 +59,6 @@ class AttentionBackend(ABC):
|
||||
def get_impl_cls() -> type["AttentionImpl"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
|
||||
return cls.get_impl_cls().get_supported_kernel_block_size()
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
|
||||
@ -79,6 +83,136 @@ class AttentionBackend(ABC):
|
||||
def full_cls_name(cls) -> tuple[str, str]:
|
||||
return (cls.__module__, cls.__qualname__)
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
return (not supported_head_sizes) or head_size in supported_head_sizes
|
||||
|
||||
@classmethod
|
||||
def supports_dtype(cls, dtype: torch.dtype) -> bool:
|
||||
return dtype in cls.supported_dtypes
|
||||
|
||||
@classmethod
|
||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
|
||||
if kv_cache_dtype is None:
|
||||
return True
|
||||
return (not cls.supported_kv_cache_dtypes) or (
|
||||
kv_cache_dtype in cls.supported_kv_cache_dtypes
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||
from vllm.config.cache import BlockSize
|
||||
|
||||
if block_size is None:
|
||||
return True
|
||||
|
||||
valid_sizes = get_args(BlockSize)
|
||||
if block_size not in valid_sizes:
|
||||
return False
|
||||
|
||||
if not cls.supported_kernel_block_sizes:
|
||||
return True
|
||||
|
||||
for supported_size in cls.supported_kernel_block_sizes:
|
||||
is_multiple_of = (
|
||||
isinstance(supported_size, MultipleOf)
|
||||
and block_size % supported_size.base == 0
|
||||
)
|
||||
is_int_equal = (
|
||||
isinstance(supported_size, int) and block_size == supported_size
|
||||
)
|
||||
if is_multiple_of or is_int_equal:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int | None,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
) -> str | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int | None,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: "DeviceCapability",
|
||||
) -> list[str]:
|
||||
invalid_reasons = []
|
||||
if not cls.supports_head_size(head_size):
|
||||
invalid_reasons.append("head_size not supported")
|
||||
if not cls.supports_dtype(dtype):
|
||||
invalid_reasons.append("dtype not supported")
|
||||
if not cls.supports_kv_cache_dtype(kv_cache_dtype):
|
||||
invalid_reasons.append("kv_cache_dtype not supported")
|
||||
if not cls.supports_block_size(block_size):
|
||||
invalid_reasons.append("block_size not supported")
|
||||
if use_mla != cls.is_mla():
|
||||
if use_mla:
|
||||
invalid_reasons.append("MLA not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-MLA not supported")
|
||||
if has_sink and not cls.supports_sink():
|
||||
invalid_reasons.append("sink setting not supported")
|
||||
if use_sparse != cls.is_sparse():
|
||||
if use_sparse:
|
||||
invalid_reasons.append("sparse not supported")
|
||||
else:
|
||||
invalid_reasons.append("non-sparse not supported")
|
||||
if not cls.supports_compute_capability(device_capability):
|
||||
invalid_reasons.append("compute capability not supported")
|
||||
combination_reason = cls.supports_combination(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
if combination_reason is not None:
|
||||
invalid_reasons.append(combination_reason)
|
||||
return invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return None
|
||||
|
||||
|
||||
class AttentionMetadata:
|
||||
pass
|
||||
@ -151,11 +285,6 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
# TODO: implement this function for all backends.
|
||||
return [MultipleOf(1)]
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -3,108 +3,192 @@
|
||||
"""Attention backend registry"""
|
||||
|
||||
import enum
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
class _Backend(enum.Enum):
|
||||
FLASH_ATTN = enum.auto()
|
||||
TRITON_ATTN = enum.auto()
|
||||
XFORMERS = enum.auto()
|
||||
ROCM_ATTN = enum.auto()
|
||||
ROCM_AITER_MLA = enum.auto()
|
||||
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
|
||||
TORCH_SDPA = enum.auto()
|
||||
FLASHINFER = enum.auto()
|
||||
FLASHINFER_MLA = enum.auto()
|
||||
TRITON_MLA = enum.auto()
|
||||
CUTLASS_MLA = enum.auto()
|
||||
FLASHMLA = enum.auto()
|
||||
FLASHMLA_SPARSE = enum.auto()
|
||||
FLASH_ATTN_MLA = enum.auto()
|
||||
PALLAS = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
FLEX_ATTENTION = enum.auto()
|
||||
TREE_ATTN = enum.auto()
|
||||
ROCM_AITER_UNIFIED_ATTN = enum.auto()
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
BACKEND_MAP = {
|
||||
_Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501
|
||||
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
|
||||
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501
|
||||
_Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501
|
||||
_Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501
|
||||
_Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501
|
||||
_Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501
|
||||
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501
|
||||
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501
|
||||
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501
|
||||
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501
|
||||
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501
|
||||
_Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501
|
||||
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
|
||||
_Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501
|
||||
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
|
||||
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501
|
||||
_Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501
|
||||
}
|
||||
class _AttentionBackendEnumMeta(enum.EnumMeta):
|
||||
"""Metaclass for AttentionBackendEnum to provide better error messages."""
|
||||
|
||||
def __getitem__(cls, name: str):
|
||||
"""Get backend by name with helpful error messages."""
|
||||
try:
|
||||
return super().__getitem__(name)
|
||||
except KeyError:
|
||||
members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values()
|
||||
valid_backends = ", ".join(m.name for m in members)
|
||||
raise ValueError(
|
||||
f"Unknown attention backend: '{name}'. "
|
||||
f"Valid options are: {valid_backends}"
|
||||
) from None
|
||||
|
||||
|
||||
def register_attn_backend(backend: _Backend, class_path: str | None = None):
|
||||
class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""Enumeration of all supported attention backends.
|
||||
|
||||
The enum value is the default class path, but this can be overridden
|
||||
at runtime using register_backend().
|
||||
|
||||
To get the actual backend class (respecting overrides), use:
|
||||
backend.get_class()
|
||||
"""
|
||||
Decorator: register a custom attention backend into BACKEND_MAPPING.
|
||||
- If class_path is provided, use it.
|
||||
- Otherwise, auto-generate from the class object.
|
||||
Validation: only checks if 'backend' is a valid _Backend enum member.
|
||||
Overwriting existing mappings is allowed. This enables other hardware
|
||||
platforms to plug in custom out-of-tree backends.
|
||||
"""
|
||||
if not isinstance(backend, _Backend):
|
||||
raise ValueError(f"{backend} is not a valid _Backend enum value.")
|
||||
|
||||
def decorator(cls):
|
||||
path = class_path or f"{cls.__module__}.{cls.__qualname__}"
|
||||
BACKEND_MAP[backend] = path
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||
ROCM_AITER_FA = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
)
|
||||
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
||||
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
FLASHMLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
|
||||
)
|
||||
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
|
||||
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
|
||||
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
|
||||
ROCM_AITER_UNIFIED_ATTN = (
|
||||
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
|
||||
"RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
# Placeholder for third-party/custom backends - must be registered before use
|
||||
CUSTOM = ""
|
||||
|
||||
def get_path(self, include_classname: bool = True) -> str:
|
||||
"""Get the class path for this backend (respects overrides).
|
||||
|
||||
Returns:
|
||||
The fully qualified class path string
|
||||
|
||||
Raises:
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
path = _OVERRIDES.get(self, self.value)
|
||||
if not path:
|
||||
raise ValueError(
|
||||
f"Backend {self.name} must be registered before use. "
|
||||
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
|
||||
)
|
||||
if not include_classname:
|
||||
path = path.rsplit(".", 1)[0]
|
||||
return path
|
||||
|
||||
def get_class(self) -> "type[AttentionBackend]":
|
||||
"""Get the backend class (respects overrides).
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
|
||||
Raises:
|
||||
ImportError: If the backend class cannot be imported
|
||||
ValueError: If Backend.CUSTOM is used without being registered
|
||||
"""
|
||||
return resolve_obj_by_qualname(self.get_path())
|
||||
|
||||
def is_overridden(self) -> bool:
|
||||
"""Check if this backend has been overridden.
|
||||
|
||||
Returns:
|
||||
True if the backend has a registered override
|
||||
"""
|
||||
return self in _OVERRIDES
|
||||
|
||||
def clear_override(self) -> None:
|
||||
"""Clear any override for this backend, reverting to the default."""
|
||||
_OVERRIDES.pop(self, None)
|
||||
|
||||
|
||||
_OVERRIDES: dict[AttentionBackendEnum, str] = {}
|
||||
|
||||
|
||||
def register_backend(
|
||||
backend: AttentionBackendEnum, class_path: str | None = None
|
||||
) -> Callable[[type], type]:
|
||||
"""Register or override a backend implementation.
|
||||
|
||||
Args:
|
||||
backend: The AttentionBackendEnum member to register
|
||||
class_path: Optional class path. If not provided and used as
|
||||
decorator, will be auto-generated from the class.
|
||||
|
||||
Returns:
|
||||
Decorator function if class_path is None, otherwise a no-op
|
||||
|
||||
Examples:
|
||||
# Override an existing backend
|
||||
@register_backend(AttentionBackendEnum.FLASH_ATTN)
|
||||
class MyCustomFlashAttn:
|
||||
...
|
||||
|
||||
# Register a custom third-party backend
|
||||
@register_backend(AttentionBackendEnum.CUSTOM)
|
||||
class MyCustomBackend:
|
||||
...
|
||||
|
||||
# Direct registration
|
||||
register_backend(
|
||||
AttentionBackendEnum.CUSTOM,
|
||||
"my.module.MyCustomBackend"
|
||||
)
|
||||
"""
|
||||
|
||||
def decorator(cls: type) -> type:
|
||||
_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
|
||||
return cls
|
||||
|
||||
if class_path is not None:
|
||||
_OVERRIDES[backend] = class_path
|
||||
return lambda x: x
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def backend_to_class_str(backend: _Backend) -> str:
|
||||
"""Get the backend class string
|
||||
# Backwards compatibility alias for plugins
|
||||
class _BackendMeta(type):
|
||||
"""Metaclass to provide deprecation warnings when accessing _Backend."""
|
||||
|
||||
Args:
|
||||
backend: The backend enum value
|
||||
def __getattribute__(cls, name: str):
|
||||
if name not in ("__class__", "__mro__", "__name__"):
|
||||
logger.warning(
|
||||
"_Backend has been renamed to AttentionBackendEnum. "
|
||||
"Please update your code to use AttentionBackendEnum instead. "
|
||||
"_Backend will be removed in a future release."
|
||||
)
|
||||
return getattr(AttentionBackendEnum, name)
|
||||
|
||||
Returns:
|
||||
The backend class string
|
||||
def __getitem__(cls, name: str):
|
||||
logger.warning(
|
||||
"_Backend has been renamed to AttentionBackendEnum. "
|
||||
"Please update your code to use AttentionBackendEnum instead. "
|
||||
"_Backend will be removed in a future release."
|
||||
)
|
||||
return AttentionBackendEnum[name]
|
||||
|
||||
|
||||
class _Backend(metaclass=_BackendMeta):
|
||||
"""Deprecated: Use AttentionBackendEnum instead.
|
||||
|
||||
This class is provided for backwards compatibility with plugins
|
||||
and will be removed in a future release.
|
||||
"""
|
||||
return BACKEND_MAP[backend]
|
||||
|
||||
|
||||
def backend_to_class(backend: _Backend) -> type:
|
||||
"""Get the backend class.
|
||||
|
||||
Args:
|
||||
backend: The backend enum value
|
||||
|
||||
Returns:
|
||||
The backend class
|
||||
"""
|
||||
backend_class_name = backend_to_class_str(backend)
|
||||
return resolve_obj_by_qualname(backend_class_name)
|
||||
|
||||
|
||||
def backend_name_to_enum(backend_name: str) -> _Backend | None:
|
||||
"""
|
||||
Convert a string backend name to a _Backend enum value.
|
||||
|
||||
Returns:
|
||||
_Backend: enum value if backend_name is a valid in-tree type
|
||||
None: otherwise it's an invalid in-tree type or an out-of-tree platform
|
||||
is loaded.
|
||||
"""
|
||||
assert backend_name is not None
|
||||
return _Backend[backend_name] if backend_name in _Backend.__members__ else None
|
||||
pass
|
||||
|
||||
@ -12,7 +12,7 @@ import torch.nn.functional as F
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
||||
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype):
|
||||
|
||||
|
||||
def maybe_get_vit_flash_attn_backend(
|
||||
attn_backend: _Backend,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
use_upstream_fa: bool,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
) -> tuple[_Backend, Callable | None]:
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> tuple[AttentionBackendEnum, Callable | None]:
|
||||
if current_platform.is_rocm():
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
attn_backend = _Backend.ROCM_AITER_FA
|
||||
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
|
||||
|
||||
elif (
|
||||
check_upstream_fa_availability(torch.get_default_dtype())
|
||||
and on_gfx9()
|
||||
and attn_backend_override is None
|
||||
):
|
||||
attn_backend = _Backend.FLASH_ATTN
|
||||
attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
else:
|
||||
return _Backend.TORCH_SDPA, None
|
||||
return AttentionBackendEnum.TORCH_SDPA, None
|
||||
|
||||
elif current_platform.is_cuda():
|
||||
if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
if (
|
||||
attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
attn_backend = _Backend.FLASH_ATTN
|
||||
attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
elif current_platform.is_xpu():
|
||||
assert attn_backend == _Backend.FLASH_ATTN, (
|
||||
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
|
||||
"XPU platform only supports FLASH_ATTN as vision attention backend."
|
||||
)
|
||||
use_upstream_fa = False
|
||||
else:
|
||||
return _Backend.TORCH_SDPA, None
|
||||
return AttentionBackendEnum.TORCH_SDPA, None
|
||||
|
||||
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
if attn_backend == _Backend.ROCM_AITER_FA:
|
||||
if attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
if use_upstream_fa:
|
||||
@ -309,7 +313,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_sharing_target_layer_name,
|
||||
**extra_impl_args,
|
||||
)
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
@ -530,13 +534,13 @@ class MultiHeadAttention(nn.Module):
|
||||
backend
|
||||
if backend
|
||||
in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.PALLAS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
_Backend.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.PALLAS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
}
|
||||
else _Backend.TORCH_SDPA
|
||||
else AttentionBackendEnum.TORCH_SDPA
|
||||
)
|
||||
|
||||
self.attn_backend, self._flash_attn_varlen_func = (
|
||||
@ -547,17 +551,23 @@ class MultiHeadAttention(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
if (
|
||||
self.attn_backend == AttentionBackendEnum.XFORMERS
|
||||
and not check_xformers_availability()
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
# this condition is just to make sure that the
|
||||
# use_upstream_fa in the log is correct
|
||||
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
):
|
||||
use_upstream_fa = True
|
||||
|
||||
logger.info_once(
|
||||
@ -606,17 +616,17 @@ class MultiHeadAttention(nn.Module):
|
||||
max_seqlen_k=kv_len,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query, key, value, scale=self.scale
|
||||
)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
|
||||
out = out.transpose(1, 2)
|
||||
elif self.attn_backend == _Backend.PALLAS:
|
||||
elif self.attn_backend == AttentionBackendEnum.PALLAS:
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
|
||||
@ -4,14 +4,15 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import cast, get_args
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_env_variable_attn_backend() -> _Backend | None:
|
||||
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
|
||||
"""
|
||||
Get the backend override specified by the vLLM attention
|
||||
backend environment variable, if one is specified.
|
||||
|
||||
Returns:
|
||||
|
||||
* _Backend enum value if an override is specified
|
||||
* AttentionBackendEnum value if an override is specified
|
||||
* None otherwise
|
||||
"""
|
||||
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
|
||||
return None if backend_name is None else backend_name_to_enum(backend_name)
|
||||
return None if backend_name is None else AttentionBackendEnum[backend_name]
|
||||
|
||||
|
||||
# Global state allows a particular choice of backend
|
||||
@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None:
|
||||
#
|
||||
# THIS SELECTION TAKES PRECEDENCE OVER THE
|
||||
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
|
||||
forced_attn_backend: _Backend | None = None
|
||||
forced_attn_backend: AttentionBackendEnum | None = None
|
||||
|
||||
|
||||
def global_force_attn_backend(attn_backend: _Backend | None) -> None:
|
||||
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
|
||||
"""
|
||||
Force all attention operations to use a specified backend.
|
||||
|
||||
@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None:
|
||||
forced_attn_backend = attn_backend
|
||||
|
||||
|
||||
def get_global_forced_attn_backend() -> _Backend | None:
|
||||
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
|
||||
"""
|
||||
Get the currently-forced choice of attention backend,
|
||||
or None if auto-selection is currently enabled.
|
||||
@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None:
|
||||
return forced_attn_backend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _IsSupported:
|
||||
can_import: bool
|
||||
head_size: bool
|
||||
dtype: bool
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.can_import and self.head_size and self.dtype
|
||||
|
||||
|
||||
def is_attn_backend_supported(
|
||||
attn_backend: str | type[AttentionBackend],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
allow_import_error: bool = True,
|
||||
) -> _IsSupported:
|
||||
if isinstance(attn_backend, str):
|
||||
try:
|
||||
attn_backend = resolve_obj_by_qualname(attn_backend)
|
||||
except ImportError:
|
||||
if not allow_import_error:
|
||||
raise
|
||||
|
||||
return _IsSupported(can_import=False, head_size=False, dtype=False)
|
||||
|
||||
assert isinstance(attn_backend, type)
|
||||
|
||||
# TODO: Update the interface once V0 is removed
|
||||
if get_supported_head_sizes := getattr(
|
||||
attn_backend, "get_supported_head_sizes", None
|
||||
):
|
||||
is_head_size_supported = head_size in get_supported_head_sizes()
|
||||
elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
|
||||
try:
|
||||
validate_head_size(head_size)
|
||||
is_head_size_supported = True
|
||||
except Exception:
|
||||
is_head_size_supported = False
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{attn_backend.__name__} does not support head size validation"
|
||||
)
|
||||
|
||||
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
|
||||
is_dtype_supported = dtype in get_supported_dtypes()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{attn_backend.__name__} does not support dtype validation"
|
||||
)
|
||||
|
||||
return _IsSupported(
|
||||
can_import=True,
|
||||
head_size=is_head_size_supported,
|
||||
dtype=is_dtype_supported,
|
||||
)
|
||||
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
block_size: int,
|
||||
block_size: int | None,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
|
||||
if kv_cache_dtype is not None:
|
||||
valid_cache_dtypes = get_args(CacheDType)
|
||||
assert kv_cache_dtype in valid_cache_dtypes, (
|
||||
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
|
||||
f"Valid values are: {valid_cache_dtypes}"
|
||||
)
|
||||
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
|
||||
block_size=block_size,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
@ -149,8 +100,8 @@ def get_attn_backend(
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
block_size: int,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int | None,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
@ -161,7 +112,9 @@ def _cached_get_attn_backend(
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: _Backend | None = get_global_forced_attn_backend()
|
||||
backend_by_global_setting: AttentionBackendEnum | None = (
|
||||
get_global_forced_attn_backend()
|
||||
)
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
@ -177,12 +130,13 @@ def _cached_get_attn_backend(
|
||||
STR_BACKEND_ENV_VAR,
|
||||
)
|
||||
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
try:
|
||||
selected_backend = AttentionBackendEnum[backend_by_env_var]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. Valid "
|
||||
f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
|
||||
) from e
|
||||
|
||||
# get device-specific attn_backend
|
||||
from vllm.platforms import current_platform
|
||||
@ -202,12 +156,26 @@ def _cached_get_attn_backend(
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
backend = resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
# Adjust kv cache layout if the selected backend requires a specific one
|
||||
required_layout = backend.get_required_kv_cache_layout()
|
||||
if required_layout is not None:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
set_kv_cache_layout(required_layout)
|
||||
logger.info(
|
||||
"Using %s KV cache layout for %s backend.",
|
||||
required_layout,
|
||||
backend.get_name(),
|
||||
)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
@contextmanager
|
||||
def global_force_attn_backend_context_manager(
|
||||
attn_backend: _Backend,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Globally force a vLLM attention backend override within a
|
||||
|
||||
@ -21,7 +21,15 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
|
||||
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||
CacheDType = Literal[
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
"fp8_inc",
|
||||
"fp8_ds_mla",
|
||||
]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
|
||||
@ -45,7 +45,7 @@ if TYPE_CHECKING:
|
||||
|
||||
import vllm.model_executor.layers.quantization as me_quant
|
||||
import vllm.model_executor.models as me_models
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -53,7 +53,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
PretrainedConfig = Any
|
||||
|
||||
_Backend = Any
|
||||
AttentionBackendEnum = Any
|
||||
me_quant = LazyLoader(
|
||||
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
||||
)
|
||||
@ -302,7 +302,7 @@ class ModelConfig:
|
||||
mm_processor_cache_type: InitVar[MMCacheType | None] = None
|
||||
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
|
||||
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
|
||||
mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
|
||||
mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
|
||||
interleave_mm_strings: InitVar[bool | None] = None
|
||||
skip_mm_profiling: InitVar[bool | None] = None
|
||||
video_pruning_rate: InitVar[float | None] = None
|
||||
@ -420,7 +420,7 @@ class ModelConfig:
|
||||
mm_processor_cache_type: MMCacheType | None,
|
||||
mm_shm_cache_max_object_size_mb: int | None,
|
||||
mm_encoder_tp_mode: MMEncoderTPMode | None,
|
||||
mm_encoder_attn_backend: _Backend | str | None,
|
||||
mm_encoder_attn_backend: AttentionBackendEnum | str | None,
|
||||
interleave_mm_strings: bool | None,
|
||||
skip_mm_profiling: bool | None,
|
||||
video_pruning_rate: float | None,
|
||||
|
||||
@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass
|
||||
from vllm.config.utils import config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
else:
|
||||
_Backend = Any
|
||||
AttentionBackendEnum = Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -125,10 +125,10 @@ class MultiModalConfig:
|
||||
DP (which is controlled by `--data-parallel-size`).
|
||||
This is only supported on a per-model basis and falls back to
|
||||
`"weights"` if the encoder does not support DP."""
|
||||
mm_encoder_attn_backend: _Backend | None = None
|
||||
mm_encoder_attn_backend: AttentionBackendEnum | None = None
|
||||
"""Optional override for the multi-modal encoder attention backend when
|
||||
using vision transformers. Accepts any value from
|
||||
`vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
|
||||
`vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
|
||||
interleave_mm_strings: bool = False
|
||||
"""Enable fully interleaved support for multimodal prompts, while using
|
||||
--chat-template-content-format=string."""
|
||||
@ -167,26 +167,16 @@ class MultiModalConfig:
|
||||
|
||||
@field_validator("mm_encoder_attn_backend", mode="before")
|
||||
@classmethod
|
||||
def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
|
||||
from vllm.attention.backends.registry import (
|
||||
_Backend as BackendEnum,
|
||||
)
|
||||
from vllm.attention.backends.registry import (
|
||||
backend_name_to_enum,
|
||||
)
|
||||
|
||||
if value is None or isinstance(value, BackendEnum):
|
||||
def _validate_mm_encoder_attn_backend(
|
||||
cls, value: str | AttentionBackendEnum | None
|
||||
) -> AttentionBackendEnum | None:
|
||||
if value is None or isinstance(value, AttentionBackendEnum):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
candidate = backend_name_to_enum(value.upper())
|
||||
if candidate is not None:
|
||||
return candidate
|
||||
|
||||
valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
|
||||
raise ValueError(
|
||||
f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
|
||||
assert isinstance(value, str), (
|
||||
"mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
|
||||
)
|
||||
return AttentionBackendEnum[value.upper()]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_multimodal_config(self):
|
||||
|
||||
@ -21,7 +21,7 @@ import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
@ -876,9 +876,9 @@ class NixlConnectorWorker:
|
||||
use_mla=self.use_mla,
|
||||
)
|
||||
self.backend_name = backend.get_name()
|
||||
attn_backend = backend_name_to_enum(self.backend_name)
|
||||
self._use_flashinfer = attn_backend == _Backend.FLASHINFER
|
||||
self._use_pallas = attn_backend == _Backend.PALLAS
|
||||
attn_backend = AttentionBackendEnum[self.backend_name]
|
||||
self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
|
||||
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
self.host_buffer_kv_cache_layout = self.kv_cache_layout
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
|
||||
@ -32,7 +32,7 @@ from pydantic.fields import FieldInfo
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
@ -462,7 +462,7 @@ class EngineArgs:
|
||||
MultiModalConfig.mm_shm_cache_max_object_size_mb
|
||||
)
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
|
||||
mm_encoder_attn_backend: _Backend | str | None = (
|
||||
mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
|
||||
MultiModalConfig.mm_encoder_attn_backend
|
||||
)
|
||||
io_processor_plugin: str | None = None
|
||||
|
||||
@ -626,14 +626,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
|
||||
# - "FLASHINFER_MLA": use FlashInfer for MLA
|
||||
# - "CUTLASS_MLA": use CUTLASS for MLA
|
||||
# All possible options loaded dynamically from _Backend enum
|
||||
# All possible options loaded dynamically from AttentionBackendEnum
|
||||
"VLLM_ATTENTION_BACKEND": env_with_choices(
|
||||
"VLLM_ATTENTION_BACKEND",
|
||||
None,
|
||||
lambda: list(
|
||||
__import__(
|
||||
"vllm.attention.backends.registry", fromlist=["_Backend"]
|
||||
)._Backend.__members__.keys()
|
||||
"vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"]
|
||||
).AttentionBackendEnum.__members__.keys()
|
||||
),
|
||||
),
|
||||
# If set, vllm will use flashinfer sampler
|
||||
|
||||
@ -9,7 +9,7 @@ import torch.nn.functional as F
|
||||
from torch.nn import LayerNorm
|
||||
from transformers.models.qwen2_vl import Qwen2VLProcessor
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
@ -256,7 +256,7 @@ class DotsVisionAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -303,17 +303,17 @@ class DotsVisionAttention(nn.Module):
|
||||
)
|
||||
)
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Unsupported vision attention backend: {self.attn_backend}"
|
||||
)
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def forward(
|
||||
@ -361,7 +361,7 @@ class DotsVisionAttention(nn.Module):
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
s = int(cu_seqlens[i - 1])
|
||||
@ -373,7 +373,7 @@ class DotsVisionAttention(nn.Module):
|
||||
out_i = out_i.permute(0, 2, 1, 3)
|
||||
outputs.append(out_i)
|
||||
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
@ -514,7 +514,7 @@ class DotsVisionBlock(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -567,7 +567,7 @@ class DotsVisionTransformer(nn.Module):
|
||||
require_post_norm: bool | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -582,10 +582,11 @@ class DotsVisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
self.out_hidden_size = config.hidden_size
|
||||
# Keep blocks for compatibility with other vision towers
|
||||
num_layers = (
|
||||
@ -666,11 +667,11 @@ class DotsVisionTransformer(nn.Module):
|
||||
) -> tuple[int | None, list[int] | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
return max_seqlen, seqlens
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
@ -164,7 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
projection_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
@ -211,17 +211,17 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Ernie45-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
@ -291,7 +291,7 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
@ -310,7 +310,7 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
@ -370,7 +370,7 @@ class Ernie4_5_VisionBlock(nn.Module):
|
||||
norm_layer: Callable[[int], nn.Module] | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -463,7 +463,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
patch_size = vision_config.patch_size
|
||||
@ -515,10 +515,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@ -565,11 +566,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
||||
) -> tuple[int | None, list[int] | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
return max_seqlen, seqlens
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import (
|
||||
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
@ -252,7 +252,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
@ -306,18 +306,18 @@ class Glm4vVisionAttention(nn.Module):
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"GLM-4V does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
@ -377,7 +377,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
@ -396,7 +396,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
@ -425,7 +425,7 @@ class Glm4vVisionBlock(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -703,7 +703,7 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -772,10 +772,11 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@ -824,8 +825,8 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
max_seqlen, seqlens = None, None
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
return max_seqlen, seqlens
|
||||
|
||||
@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from transformers.utils import torch_int
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
@ -360,7 +360,7 @@ class KeyeSiglipAttention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -414,17 +414,17 @@ class KeyeSiglipAttention(nn.Module):
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Keye-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def forward(
|
||||
@ -489,7 +489,7 @@ class KeyeSiglipAttention(nn.Module):
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
@ -536,7 +536,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -590,7 +590,7 @@ class KeyeSiglipEncoder(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -685,7 +685,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -768,7 +768,7 @@ class KeyeSiglipVisionModel(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
@ -106,7 +106,7 @@ class VisualTokenizer(torch.nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -135,7 +135,7 @@ class VisualTokenizer(torch.nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
model_type = config.model_type
|
||||
if model_type == "siglip2_navit":
|
||||
|
||||
@ -31,7 +31,7 @@ from transformers.modeling_outputs import (
|
||||
)
|
||||
from transformers.utils import torch_int
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
@ -580,8 +580,8 @@ class SiglipAttention(nn.Module):
|
||||
projection_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -621,8 +621,8 @@ class SiglipAttention(nn.Module):
|
||||
)
|
||||
)
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
@ -680,10 +680,10 @@ class SiglipAttention(nn.Module):
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
batch_size,
|
||||
self.attn_backend == _Backend.ROCM_AITER_FA,
|
||||
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
|
||||
self.use_upstream_fa,
|
||||
)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
@ -702,7 +702,7 @@ class SiglipAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
if seqlens is None:
|
||||
raise ValueError("xFormers attention backend requires seqlens tensor.")
|
||||
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
||||
@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
use_upstream_fa: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module):
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
} and check_upstream_fa_availability(torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
|
||||
@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module):
|
||||
|
||||
max_seqlen = None
|
||||
seqlens = None
|
||||
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module):
|
||||
config,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
|
||||
from vllm.attention.ops.vit_attn_wrappers import (
|
||||
vit_flash_attn_wrapper,
|
||||
@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
):
|
||||
self.use_upstream_fa = True
|
||||
if current_platform.is_xpu():
|
||||
self.use_upstream_fa = False
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
batch_size,
|
||||
self.attn_backend == _Backend.ROCM_AITER_FA,
|
||||
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
|
||||
self.use_upstream_fa,
|
||||
)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
v,
|
||||
cu_seqlens,
|
||||
)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
|
||||
|
||||
output, _ = self.proj(context_layer)
|
||||
@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return max_seqlen, seqlens
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module):
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
output, "(b s) h d -> s b (h d)", b=batch_size
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
context_layer = rearrange(
|
||||
context_layer, "b s h d -> s b (h d)"
|
||||
).contiguous()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@ -789,9 +790,12 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
self, cu_seqlens: torch.Tensor
|
||||
) -> tuple[int | None, list[int] | None]:
|
||||
max_seqlen, seqlens = None, None
|
||||
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||
if self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
return max_seqlen, seqlens
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
|
||||
)
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = vision_config.hidden_size
|
||||
@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||
torch.get_default_dtype()
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return max_seqlen, seqlens
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
|
||||
)
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = vision_config.hidden_size
|
||||
@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
)
|
||||
use_upstream_fa = False
|
||||
if (
|
||||
self.attn_backend != _Backend.FLASH_ATTN
|
||||
and self.attn_backend != _Backend.ROCM_AITER_FA
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.XFORMERS,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now."
|
||||
@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||
if (
|
||||
self.attn_backend == _Backend.FLASH_ATTN
|
||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
||||
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
|
||||
):
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return max_seqlen, seqlens
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from torch.nn import functional as F
|
||||
from transformers import Siglip2VisionConfig
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -208,7 +208,7 @@ class Siglip2Attention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -264,14 +264,14 @@ class Siglip2Attention(nn.Module):
|
||||
)
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}:
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.ROCM_AITER_FA,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def forward(
|
||||
@ -308,7 +308,7 @@ class Siglip2Attention(nn.Module):
|
||||
attn_output = self.flash_attn_varlen_func(
|
||||
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
|
||||
).reshape(seq_length, -1)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
batch_size = cu_seqlens.shape[0] - 1
|
||||
outputs = []
|
||||
@ -376,7 +376,7 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -440,7 +440,7 @@ class Siglip2Encoder(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -626,7 +626,7 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -667,7 +667,7 @@ class Siglip2NavitModel(torch.nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@ -83,8 +83,8 @@ def get_vit_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
) -> _Backend:
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> AttentionBackendEnum:
|
||||
"""
|
||||
Get the available attention backend for Vision Transformer.
|
||||
"""
|
||||
@ -94,7 +94,7 @@ def get_vit_attn_backend(
|
||||
# Lazy import to avoid circular dependency
|
||||
from vllm.attention.selector import get_env_variable_attn_backend
|
||||
|
||||
selected_backend: _Backend | None = get_env_variable_attn_backend()
|
||||
selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
|
||||
if selected_backend is not None:
|
||||
return selected_backend
|
||||
|
||||
|
||||
@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
@ -127,7 +127,7 @@ class CpuPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@ -137,9 +137,9 @@ class CpuPlatform(Platform):
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
if use_mla:
|
||||
raise NotImplementedError("MLA is not supported on CPU.")
|
||||
@ -148,7 +148,7 @@ class CpuPlatform(Platform):
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
if not use_v1:
|
||||
raise ValueError("CPU backend only supports V1.")
|
||||
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
||||
return AttentionBackendEnum.TORCH_SDPA.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
|
||||
@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
else:
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
CacheDType = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -39,6 +42,49 @@ pynvml = import_pynvml()
|
||||
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
|
||||
@cache
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
return [
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
if device_capability.major == 10:
|
||||
return [
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
]
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
@ -216,217 +262,171 @@ class CudaPlatformBase(Platform):
|
||||
return torch.cuda.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
# For Blackwell GPUs, force TORCH_SDPA for now.
|
||||
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
|
||||
if cls.has_device_capability(100):
|
||||
return _Backend.TORCH_SDPA
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
if dtype not in (torch.float16, torch.bfloat16):
|
||||
return _Backend.XFORMERS
|
||||
return AttentionBackendEnum.XFORMERS
|
||||
|
||||
if cls.has_device_capability(80):
|
||||
FLASH_ATTN_V1 = (
|
||||
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
)
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
|
||||
is_default_fa_supported = is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
|
||||
)
|
||||
if is_default_fa_supported:
|
||||
return _Backend.FLASH_ATTN
|
||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||
if backend_class.supports_head_size(
|
||||
head_size
|
||||
) and backend_class.supports_dtype(dtype):
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
else:
|
||||
# Fallback to XFORMERS
|
||||
return _Backend.XFORMERS
|
||||
return AttentionBackendEnum.XFORMERS
|
||||
else:
|
||||
# Fallback for Volta/Turing GPUs or FA not supported
|
||||
return _Backend.XFORMERS
|
||||
return AttentionBackendEnum.XFORMERS
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
def get_valid_backends(
|
||||
cls,
|
||||
selected_backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
]:
|
||||
valid_backends_priorities = []
|
||||
invalid_reasons = {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(use_mla, device_capability)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
backend_class = backend.get_class()
|
||||
invalid_reasons_i = backend_class.validate_configuration(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons_i = ["ImportError"]
|
||||
if invalid_reasons_i:
|
||||
invalid_reasons[backend] = invalid_reasons_i
|
||||
else:
|
||||
valid_backends_priorities.append((backend, priority))
|
||||
|
||||
return valid_backends_priorities, invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int | None,
|
||||
use_v1: bool,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
||||
"to select a supported backend."
|
||||
)
|
||||
|
||||
if use_mla:
|
||||
# explicitly reject non-MLA backends when MLA is enabled to avoid
|
||||
# silently selecting an incompatible backend (e.g., FLASHINFER).
|
||||
if selected_backend in {
|
||||
_Backend.FLASHINFER,
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
_Backend.TREE_ATTN,
|
||||
_Backend.XFORMERS,
|
||||
}:
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
|
||||
# First try checking just the selected backend, if there is one.
|
||||
if selected_backend is not None:
|
||||
try:
|
||||
backend_class = selected_backend.get_class()
|
||||
invalid_reasons = backend_class.validate_configuration(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
None,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons = ["ImportError"]
|
||||
if invalid_reasons:
|
||||
raise ValueError(
|
||||
f"Attention backend {selected_backend} incompatible with MLA. "
|
||||
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, "
|
||||
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
|
||||
"VLLM_MLA_DISABLE=1 to disable MLA for this model."
|
||||
f"Selected backend {selected_backend} is not valid for "
|
||||
f"this configuration. Reason: {invalid_reasons}"
|
||||
)
|
||||
else:
|
||||
logger.info("Using %s backend.", selected_backend)
|
||||
return selected_backend.get_path()
|
||||
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||
|
||||
if use_sparse:
|
||||
logger.info_once("Using Sparse MLA backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse."
|
||||
"FlashMLASparseBackend"
|
||||
)
|
||||
|
||||
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
|
||||
selected_backend is None
|
||||
and cls.is_device_capability(100)
|
||||
and block_size % 128 == 0
|
||||
)
|
||||
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
|
||||
selected_backend is None
|
||||
and cls.is_device_capability(100)
|
||||
and (block_size == 32 or block_size % 64 == 0)
|
||||
)
|
||||
use_flashmla = selected_backend == _Backend.FLASHMLA or (
|
||||
selected_backend is None and is_flashmla_dense_supported()[0]
|
||||
)
|
||||
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
|
||||
selected_backend is None and flash_attn_supports_mla()
|
||||
)
|
||||
use_triton = selected_backend == _Backend.TRITON_MLA or (
|
||||
selected_backend is None
|
||||
)
|
||||
|
||||
if use_cutlassmla:
|
||||
logger.info_once("Using Cutlass MLA backend.", scope="local")
|
||||
return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
if use_flashinfermla:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
set_kv_cache_layout("HND")
|
||||
logger.info_once("Using FlashInfer MLA backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
if use_flashmla:
|
||||
if block_size % 64 != 0:
|
||||
logger.warning(
|
||||
"FlashMLA backend is not supported for block size %d"
|
||||
" (currently only supports block size 64).",
|
||||
block_size,
|
||||
)
|
||||
else:
|
||||
logger.info_once("Using FlashMLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
if use_flashattn:
|
||||
logger.info_once("Using FlashAttention MLA backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
)
|
||||
if use_triton:
|
||||
logger.info_once("Using Triton MLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
|
||||
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
|
||||
FLEX_ATTENTION_V1 = (
|
||||
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||
# No selected backend or the selected backend is invalid,
|
||||
# so we try finding a valid backend.
|
||||
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
None,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
|
||||
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
|
||||
reasons_str = (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f"{backend.name}: [{', '.join(reasons)}]"
|
||||
for backend, reasons in invalid_reasons.items()
|
||||
)
|
||||
+ "}"
|
||||
)
|
||||
config_str = (
|
||||
f"head_size: {head_size}, dtype: {dtype}, "
|
||||
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
|
||||
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
|
||||
)
|
||||
logger.debug_once(
|
||||
f"Some attention backends are not valid for {cls.device_name} with "
|
||||
f"{config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
if len(valid_backends_priorities) == 0:
|
||||
raise ValueError(
|
||||
f"No valid attention backend found for {cls.device_name} "
|
||||
f"with {config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
|
||||
use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
|
||||
"fp8"
|
||||
# We have found some valid backends. Select the one with the
|
||||
# highest priority.
|
||||
logger.info(
|
||||
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
|
||||
)
|
||||
sorted_indices = sorted(
|
||||
range(len(valid_backends_priorities)),
|
||||
key=lambda i: valid_backends_priorities[i][1],
|
||||
)
|
||||
selected_index = sorted_indices[0]
|
||||
selected_backend = valid_backends_priorities[selected_index][0]
|
||||
logger.info(
|
||||
"Using %s backend.",
|
||||
selected_backend.name,
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.FLASHINFER:
|
||||
logger.info_once("Using FlashInfer backend.")
|
||||
if cls.has_device_capability(100):
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
set_kv_cache_layout("HND")
|
||||
return FLASHINFER_V1
|
||||
elif selected_backend == _Backend.FLEX_ATTENTION:
|
||||
logger.info_once("Using FlexAttention backend.")
|
||||
return FLEX_ATTENTION_V1
|
||||
elif selected_backend == _Backend.TRITON_ATTN:
|
||||
logger.info_once("Using Triton backend.")
|
||||
return TRITON_ATTN
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend.")
|
||||
return FLASH_ATTN_V1
|
||||
elif selected_backend == _Backend.TREE_ATTN:
|
||||
logger.info_once("Using Tree Attention backend.")
|
||||
return TREE_ATTN_V1
|
||||
elif selected_backend == _Backend.XFORMERS:
|
||||
logger.info_once("Using XFormers backend.")
|
||||
return XFORMERS_V1
|
||||
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
|
||||
# Default backends for V1 engine
|
||||
# Prefer FlashInfer for Blackwell GPUs if installed
|
||||
if cls.is_device_capability(100):
|
||||
if is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASHINFER_V1, head_size, dtype
|
||||
):
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
logger.info_once(
|
||||
"Using FlashInfer backend with HND KV cache layout on "
|
||||
"V1 engine by default for Blackwell (SM 10.0) GPUs."
|
||||
)
|
||||
set_kv_cache_layout("HND")
|
||||
|
||||
return FLASHINFER_V1
|
||||
|
||||
if not is_default_backend_supported.can_import:
|
||||
logger.warning_once(
|
||||
"FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
|
||||
"it is recommended to install FlashInfer for better "
|
||||
"performance."
|
||||
)
|
||||
|
||||
# FlashAttention is the default for SM 8.0+ GPUs
|
||||
if cls.has_device_capability(80):
|
||||
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
|
||||
logger.info_once("Using Triton backend.")
|
||||
return TRITON_ATTN
|
||||
elif is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
|
||||
):
|
||||
logger.info_once("Using Flash Attention backend.")
|
||||
return FLASH_ATTN_V1
|
||||
|
||||
# FlexAttention is the default for older GPUs
|
||||
else:
|
||||
logger.info_once("Using FlexAttention backend.")
|
||||
return FLEX_ATTENTION_V1
|
||||
|
||||
assert not is_default_backend_supported
|
||||
|
||||
use_flex_attention_reason = {}
|
||||
if not is_default_backend_supported.head_size:
|
||||
use_flex_attention_reason["head_size"] = head_size
|
||||
if not is_default_backend_supported.dtype:
|
||||
use_flex_attention_reason["dtype"] = dtype
|
||||
|
||||
logger.info_once(
|
||||
"Using FlexAttention backend for %s.",
|
||||
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
|
||||
)
|
||||
return FLEX_ATTENTION_V1
|
||||
return selected_backend.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
|
||||
@ -17,8 +17,9 @@ from vllm.logger import init_logger
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) < (other.major, other.minor)
|
||||
|
||||
def __le__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) <= (other.major, other.minor)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) == (other.major, other.minor)
|
||||
|
||||
def __ge__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) >= (other.major, other.minor)
|
||||
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) > (other.major, other.minor)
|
||||
|
||||
def as_version_str(self) -> str:
|
||||
return f"{self.major}.{self.minor}"
|
||||
|
||||
@ -173,19 +199,21 @@ class Platform:
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
# Import _Backend here to avoid circular import.
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
# Import AttentionBackendEnum here to avoid circular import.
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
return _Backend.TORCH_SDPA
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int,
|
||||
use_v1: bool,
|
||||
use_mla: bool,
|
||||
|
||||
@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -204,21 +204,23 @@ class RocmPlatform(Platform):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> AttentionBackendEnum:
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
# Note: AITER FA is only supported for Qwen-VL models.
|
||||
# TODO: Add support for other VL models in their model class.
|
||||
return _Backend.ROCM_AITER_FA
|
||||
return AttentionBackendEnum.ROCM_AITER_FA
|
||||
|
||||
if on_gfx9() and find_spec("flash_attn") is not None:
|
||||
return _Backend.FLASH_ATTN
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
return _Backend.TORCH_SDPA
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
@ -234,7 +236,7 @@ class RocmPlatform(Platform):
|
||||
use_sparse,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
|
||||
@ -248,55 +250,52 @@ class RocmPlatform(Platform):
|
||||
if use_mla:
|
||||
if selected_backend is None:
|
||||
selected_backend = (
|
||||
_Backend.ROCM_AITER_MLA
|
||||
AttentionBackendEnum.ROCM_AITER_MLA
|
||||
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
|
||||
else _Backend.TRITON_MLA
|
||||
else AttentionBackendEnum.TRITON_MLA
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.TRITON_MLA:
|
||||
if selected_backend == AttentionBackendEnum.TRITON_MLA:
|
||||
if block_size != 1:
|
||||
logger.info_once("Using Triton MLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
return AttentionBackendEnum.TRITON_MLA.get_path()
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}."
|
||||
)
|
||||
if selected_backend == _Backend.ROCM_AITER_MLA:
|
||||
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
|
||||
logger.info("Using AITER MLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||
return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
|
||||
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"is not MLA type while requested for MLA backend."
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.FLEX_ATTENTION:
|
||||
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
|
||||
logger.info("Using FlexAttention backend.")
|
||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
if (
|
||||
rocm_aiter_ops.is_mha_enabled()
|
||||
) or selected_backend == _Backend.ROCM_AITER_FA:
|
||||
) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||
if (
|
||||
rocm_aiter_ops.is_triton_unified_attn_enabled()
|
||||
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
||||
logger.info("Using Aiter Unified Attention backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends."
|
||||
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
||||
if (
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
or selected_backend == _Backend.ROCM_ATTN
|
||||
or selected_backend == AttentionBackendEnum.ROCM_ATTN
|
||||
):
|
||||
# rocm specific backend, with aiter and/or
|
||||
# triton prefix-prefill
|
||||
logger.info("Using Rocm Attention backend.")
|
||||
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
||||
# default case, using triton unified attention
|
||||
logger.info("Using Triton Attention backend.")
|
||||
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
|
||||
@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import BlockSize
|
||||
from vllm.pooling_params import PoolingParams
|
||||
else:
|
||||
BlockSize = None
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -54,7 +53,7 @@ class TpuPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@ -64,17 +63,17 @@ class TpuPlatform(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
||||
if selected_backend != _Backend.PALLAS:
|
||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
|
||||
if not use_v1:
|
||||
raise ValueError("TPU backend only supports V1.")
|
||||
logger.info("Using Pallas V1 backend.")
|
||||
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
return AttentionBackendEnum.PALLAS.get_path()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
|
||||
@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -44,7 +43,7 @@ class XPUPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@ -62,18 +61,19 @@ class XPUPlatform(Platform):
|
||||
"only NHD layout is supported by XPU attention kernels."
|
||||
)
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on XPU.")
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
if selected_backend == _Backend.TRITON_ATTN:
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
if not use_v1:
|
||||
raise ValueError("XPU backend only supports V1.")
|
||||
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
logger.info_once("Using Triton backend.")
|
||||
return TRITON_ATTN
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend.")
|
||||
return FLASH_ATTN
|
||||
return AttentionBackendEnum.FLASH_ATTN.get_path()
|
||||
elif selected_backend:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {cls.device_name}, "
|
||||
@ -81,7 +81,7 @@ class XPUPlatform(Platform):
|
||||
)
|
||||
|
||||
logger.info("Using Flash Attention backend.")
|
||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
return AttentionBackendEnum.FLASH_ATTN.get_path()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
@ -113,10 +113,10 @@ class XPUPlatform(Platform):
|
||||
return device_props.total_memory
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
return _Backend.FLASH_ATTN
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> AttentionBackendEnum:
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -40,23 +40,16 @@ logger = init_logger(__name__)
|
||||
|
||||
class TorchSDPABackend(AttentionBackend):
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
attn_impl = _get_paged_attn_impl()
|
||||
is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size)
|
||||
if not is_valid:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
return attn_impl.get_supported_head_sizes()
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -759,9 +752,8 @@ def _make_sliding_window_bias(
|
||||
|
||||
class _PagedAttention:
|
||||
@staticmethod
|
||||
def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
|
||||
SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256]
|
||||
return head_size in SUPPORT_HS, SUPPORT_HS
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return [32, 64, 80, 96, 112, 128, 192, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
@ -861,8 +853,8 @@ class _PagedAttention:
|
||||
|
||||
class _IPEXPagedAttention(_PagedAttention):
|
||||
@staticmethod
|
||||
def validate_head_size(head_size: int) -> tuple[bool, list[int]]:
|
||||
return True, []
|
||||
def get_supported_head_sizes() -> list[int]:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
"""Attention layer with FlashAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -32,11 +33,13 @@ if is_flash_attn_varlen_func_available():
|
||||
reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
@ -52,34 +55,12 @@ logger = init_logger(__name__)
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
# NOTE(tdoublep): while in principle, FA supports
|
||||
# MultipleOf(16), these are the block sizes that do not
|
||||
# suffer from the NaN propagation problem described here:
|
||||
# https://github.com/Dao-AILab/flash-attention/issues/1974
|
||||
return [16, 32, 64]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
# NOTE(tdoublep): while in principle, FA supports
|
||||
# MultipleOf(16), these are the block sizes that do not
|
||||
# suffer from the NaN propagation problem described here:
|
||||
# https://github.com/Dao-AILab/flash-attention/issues/1974
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -125,6 +106,38 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
else:
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
|
||||
if kv_cache_dtype is None:
|
||||
return True
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
return flash_attn_supports_fp8()
|
||||
return kv_cache_dtype in ["auto"]
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability >= DeviceCapability(8, 0)
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if has_sink and device_capability < DeviceCapability(9, 0):
|
||||
return "sink not supported on compute capability < 9.0"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
@ -481,8 +494,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
FlashAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
# Cache the batch invariant result for use in forward passes
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import (
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.flashinfer import (
|
||||
can_use_trtllm_attention,
|
||||
@ -45,6 +47,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
KVCacheLayoutType,
|
||||
get_kv_cache_layout,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters,
|
||||
@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant(
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
||||
return [64, 128, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
# Note: Not sure for all platforms,
|
||||
# but on Blackwell, only support a page size of
|
||||
# 16, 32, 64
|
||||
return [16, 32, 64]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
# Note: Not sure for all platforms,
|
||||
# but on Blackwell, only support a page size of
|
||||
# 16, 32, 64
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -231,6 +217,26 @@ class FlashInferBackend(AttentionBackend):
|
||||
else:
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
||||
return [64, 128, 256]
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
|
||||
12, 1
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability is not None and capability.major == 10:
|
||||
return "HND"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
@ -328,7 +334,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
)
|
||||
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
|
||||
self.head_dim = self.kv_cache_spec.head_size
|
||||
FlashInferBackend.validate_head_size(self.head_dim)
|
||||
self.page_size = self.kv_cache_spec.block_size
|
||||
|
||||
self.cache_dtype = self.cache_config.cache_dtype
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch._dynamo.decorators
|
||||
@ -24,6 +25,7 @@ from vllm.attention.backends.abstract import (
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
|
||||
|
||||
class FlexAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
return # FlexAttention supports any head size
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -106,6 +106,10 @@ class FlexAttentionBackend(AttentionBackend):
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
|
||||
# @torch.compile(fullgraph=True, mode="reduce-overhead")
|
||||
def physical_to_logical_mapping(
|
||||
@ -720,7 +724,6 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("FlexAttention does not support kv sharing yet.")
|
||||
|
||||
FlexAttentionBackend.validate_head_size(head_size)
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support quantized kv-cache. Yet"
|
||||
|
||||
@ -308,25 +308,13 @@ class MLACommonBackend(AttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]):
|
||||
) = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_dim is not None:
|
||||
MLACommonBackend.validate_head_size(self.head_dim)
|
||||
if self.head_dim is not None and not MLACommonBackend.supports_head_size(
|
||||
self.head_dim
|
||||
):
|
||||
raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")
|
||||
|
||||
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
@ -13,7 +13,9 @@ from vllm.attention.backends.abstract import (
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CUTLASS_MLA"
|
||||
@ -45,9 +55,9 @@ class CutlassMLABackend(MLACommonBackend):
|
||||
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||
return CutlassMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [128]
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
|
||||
class SM100Workspace:
|
||||
|
||||
@ -10,6 +10,7 @@ from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
@ -17,10 +18,12 @@ from vllm.attention.utils.fa_utils import (
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@ -37,6 +40,10 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttnMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_MLA"
|
||||
@ -49,6 +56,26 @@ class FlashAttnMLABackend(MLACommonBackend):
|
||||
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
|
||||
return FlashAttnMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 9
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if not flash_attn_supports_mla():
|
||||
return "FlashAttention MLA not supported on this device"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
|
||||
@ -6,8 +6,14 @@ from typing import ClassVar
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
@ -15,7 +21,7 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA"
|
||||
@ -41,8 +55,12 @@ class FlashInferMLABackend(MLACommonBackend):
|
||||
return FlashInferMLAMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
|
||||
return [32, 64]
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
|
||||
@ -13,10 +13,12 @@ from vllm.attention.ops.flashmla import (
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@ -36,6 +38,14 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
@ -48,9 +58,30 @@ class FlashMLABackend(MLACommonBackend):
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [64]
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
if use_sparse:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
|
||||
|
||||
return is_flashmla_sparse_supported()[1]
|
||||
else:
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
return is_flashmla_dense_supported()[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.flashmla import (
|
||||
@ -18,8 +19,10 @@ from vllm.attention.ops.flashmla import (
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
@ -51,6 +54,9 @@ structured as:
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -64,6 +70,22 @@ class FlashMLASparseBackend(AttentionBackend):
|
||||
def get_impl_cls() -> type["FlashMLASparseImpl"]:
|
||||
return FlashMLASparseImpl
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major in [9, 10]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
@ -79,14 +101,6 @@ class FlashMLASparseBackend(AttentionBackend):
|
||||
else:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
|
||||
@ -23,6 +23,8 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 128]
|
||||
@ -46,10 +48,6 @@ class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
return (0, 1, 2)
|
||||
|
||||
@classmethod
|
||||
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
|
||||
return [64]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,11 +13,13 @@ from vllm.attention.backends.abstract import (
|
||||
)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
@ -28,6 +31,9 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
@ -36,6 +42,10 @@ class TritonMLABackend(MLACommonBackend):
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
@ -445,31 +446,13 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
|
||||
class AiterFlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [64, 128, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
@ -531,8 +514,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
AiterFlashAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
|
||||
@ -152,10 +152,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
|
||||
|
||||
class RocmAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
@ -163,12 +160,11 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
if not cls.supports_head_size(head_size):
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -30,31 +30,13 @@ logger = init_logger(__name__)
|
||||
|
||||
class TreeAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TREE_ATTN"
|
||||
@ -331,8 +313,6 @@ class TreeAttentionImpl(AttentionImpl):
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
|
||||
TreeAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
|
||||
@ -18,12 +18,14 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
@ -147,25 +149,18 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
# Triton Attention supports any head size above 32
|
||||
if head_size < 32:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by TritonAttention."
|
||||
f"Head sizes need to be larger or equal 32 for this backend. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -195,6 +190,18 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
|
||||
return TritonAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size >= 32
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
@ -237,8 +244,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
TritonAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""Attention layer with XFormersAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -41,10 +41,8 @@ logger = init_logger(__name__)
|
||||
|
||||
class XFormersAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
@ -80,22 +78,6 @@ class XFormersAttentionBackend(AttentionBackend):
|
||||
256,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "XFORMERS"
|
||||
@ -305,8 +287,6 @@ class XFormersAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
XFormersAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
|
||||
@ -150,11 +150,15 @@ class EagleProposer:
|
||||
)
|
||||
|
||||
# Determine allowed attention backends once during initialization.
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
self.allowed_attn_types: tuple | None = None
|
||||
if current_platform.is_rocm():
|
||||
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
||||
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
|
||||
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
|
||||
# ROCM_AITER_FA is an optional backend
|
||||
if find_spec(
|
||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
|
||||
):
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||
AiterFlashAttentionMetadata,
|
||||
)
|
||||
|
||||
@ -4371,7 +4371,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
"""
|
||||
for backend in backends:
|
||||
is_supported = False
|
||||
for supported_size in backend.get_supported_kernel_block_size():
|
||||
for supported_size in backend.supported_kernel_block_sizes:
|
||||
if isinstance(supported_size, int):
|
||||
if block_size == supported_size:
|
||||
is_supported = True
|
||||
@ -4402,7 +4402,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
all_int_supported_sizes = set(
|
||||
supported_size
|
||||
for backend in backends
|
||||
for supported_size in backend.get_supported_kernel_block_size()
|
||||
for supported_size in backend.supported_kernel_block_sizes
|
||||
if isinstance(supported_size, int)
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user