[Attention] Implement universal BACKEND_MAP (#25900)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni 2025-10-08 15:00:25 -04:00 committed by GitHub
parent b25d7b5657
commit 76879cc160
12 changed files with 119 additions and 75 deletions

View File

@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_FLASH"],
"hip": ["ROCM_ATTN"],
"cpu": ["TORCH_SDPA"],
}
@@ -122,7 +122,7 @@ def test_env(
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "TRITON_ATTN"
expected = "ROCM_ATTN"
assert backend.get_name() == expected
elif device == "cuda":

View File

@@ -18,7 +18,7 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")
# Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())

View File

@@ -14,7 +14,7 @@ from tests.v1.attention.utils import (
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
@@ -214,7 +214,7 @@ def run_attention_backend(
actual_backend = _Backend.FLEX_ATTENTION
use_direct_block_mask = False
builder_cls, impl_cls = get_attention_backend(actual_backend)
builder_cls, impl_cls = try_get_attention_backend(actual_backend)
# Mock flashinfer's get_per_layer_parameters if needed
if actual_backend == _Backend.FLASHINFER:

View File

@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
try_get_attention_backend,
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
@@ -239,7 +239,7 @@ def run_attention_backend(
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls, impl_cls = get_attention_backend(backend)
builder_cls, impl_cls = try_get_attention_backend(backend)
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
@@ -400,7 +400,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
builder_cls, _ = try_get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
# Split q into nope and rope components

View File

@@ -8,7 +8,8 @@ from typing import Optional, Union
import pytest
import torch
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.config import (
CacheConfig,
CompilationConfig,
@@ -20,9 +21,11 @@ from vllm.config import (
VllmConfig,
)
from vllm.config.model import ModelDType
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -117,44 +120,17 @@ def create_common_attn_metadata(
)
def get_attention_backend(backend_name: _Backend):
"""Set up attention backend classes for testing.
Args:
backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
vllm_config: VllmConfig instance
Returns:
Tuple of (backend_builder_class, backend_impl_class)
"""
backend_map = {
_Backend.FLASH_ATTN: (
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if current_platform.is_cuda()
else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
),
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501
}
if backend_name not in backend_map:
raise ValueError(f"Unknown backend: {backend_name}")
backend_class_name = backend_map[backend_name]
def try_get_attention_backend(
backend: _Backend,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
"""Try to get the attention backend class, skipping test if not found."""
backend_class_str = backend_to_class_str(backend)
try:
backend_class = resolve_obj_by_qualname(backend_class_name)
backend_class = resolve_obj_by_qualname(backend_class_str)
return backend_class.get_builder_cls(), backend_class.get_impl_cls()
except ImportError as e:
pytest.skip(f"{backend_name} not available: {e}")
pytest.skip(f"{backend_class_str} not available: {e}")
raise AssertionError("unreachable") from None
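For illustration, a sketch of how the updated tests consume this helper (the model name, layer name, and device below are illustrative; the fixture functions are the ones defined in this same utils module):

import torch

from tests.v1.attention.utils import (
    create_standard_kv_cache_spec,
    create_vllm_config,
    try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend

# Resolves the backend's builder/impl classes, or skips the calling test
# via pytest.skip() when the backend cannot be imported on this platform.
builder_cls, impl_cls = try_get_attention_backend(_Backend.TRITON_ATTN)

vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B", max_model_len=1024)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# Same constructor pattern as run_attention_backend() in the tests above.
builder = builder_cls(kv_cache_spec, ["model.layers.0.attn"], vllm_config, torch.device("cuda"))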
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:

View File

@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import (
@@ -535,11 +535,11 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock()
if attn_backend == "FLASH_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
elif attn_backend == "TRITON_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")
@@ -674,7 +674,7 @@ def test_propose_tree(spec_token_tree):
proposer.attn_layer_names = ["layer.0"]
# Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,

View File

@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import (
@@ -177,7 +177,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
sampling_metadata = mock.MagicMock()
# Setup attention metadata
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),

View File

@@ -9,7 +9,7 @@ import torch
from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
@@ -63,7 +63,7 @@ def forward_attention(
# Build common metadata.
model_name = "meta-llama/Meta-Llama-3-8B"
builder_cls, impl_cls = get_attention_backend(backend)
builder_cls, impl_cls = try_get_attention_backend(backend)
vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
if spec_token_tree is not None:
# Create speculative config if token tree is specified.

View File

@@ -3,13 +3,16 @@
"""Attention backend registry"""
import enum
from typing import Optional
from vllm.utils import resolve_obj_by_qualname
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
@@ -24,5 +27,83 @@ class _Backend(enum.Enum):
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()
BACKEND_MAP = {
_Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501
_Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501
_Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501
_Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
_Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501
_Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501
}
def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
"""
Decorator: register a custom attention backend into BACKEND_MAP.
- If class_path is provided, use it.
- Otherwise, derive it from the decorated class's module and qualified name.
Validation: only checks if 'backend' is a valid _Backend enum member.
Overwriting existing mappings is allowed. This enables other hardware
platforms to plug in custom out-of-tree backends.
"""
if not isinstance(backend, _Backend):
raise ValueError(f"{backend} is not a valid _Backend enum value.")
def decorator(cls):
path = class_path or f"{cls.__module__}.{cls.__qualname__}"
BACKEND_MAP[backend] = path
return cls
return decorator
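For example, an out-of-tree platform plugin could register (or overwrite) the implementation behind an existing enum member; the backend class below is hypothetical:

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend, register_attn_backend

# Hypothetical out-of-tree backend; overwriting the in-tree mapping is allowed.
@register_attn_backend(_Backend.FLASH_ATTN)
class MyPlatformFlashAttentionBackend(AttentionBackend):
    ...

# class_path may also be given explicitly, e.g.
# @register_attn_backend(_Backend.FLASH_ATTN, class_path="my_plugin.attn.MyBackend"),
# when the registered import path should differ from the decorated class's location.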
def backend_to_class_str(backend: _Backend) -> str:
"""Get the backend class string
Args:
backend: The backend enum value
Returns:
The backend class string
"""
return BACKEND_MAP[backend]
def backend_to_class(backend: _Backend) -> type:
"""Get the backend class.
Args:
backend: The backend enum value
Returns:
The backend class
"""
backend_class_name = backend_to_class_str(backend)
return resolve_obj_by_qualname(backend_class_name)
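A minimal usage sketch of these two accessors (the path is the one registered in BACKEND_MAP above; backend_to_class performs the actual import, so it only succeeds when the backend's dependencies are installed):

from vllm.attention.backends.registry import (
    _Backend,
    backend_to_class,
    backend_to_class_str,
)

path = backend_to_class_str(_Backend.TRITON_ATTN)
# "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
backend_cls = backend_to_class(_Backend.TRITON_ATTN)
assert backend_cls.__name__ == "TritonAttentionBackend"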
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
"""
Convert a string backend name to a _Backend enum value.
Returns:
_Backend: enum value if backend_name is a valid in-tree type
None: otherwise it's an invalid in-tree type or an out-of-tree platform
is loaded.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None
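And a short example of the string-to-enum conversion, which the selector and layer code now import from here (values follow directly from the enum definition above):

from vllm.attention.backends.registry import _Backend, backend_name_to_enum

assert backend_name_to_enum("FLASH_ATTN") is _Backend.FLASH_ATTN
assert backend_name_to_enum("ROCM_ATTN") is _Backend.ROCM_ATTN
# Unknown names (e.g. a backend provided by an out-of-tree platform) return None.
assert backend_name_to_enum("SOME_OOT_BACKEND") is None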

View File

@@ -11,8 +11,8 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (

View File

@@ -12,7 +12,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
@@ -20,19 +20,6 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__)
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
"""
Convert a string backend name to a _Backend enum value.
Returns:
* _Backend: enum value if backend_name is a valid in-tree type
* None: otherwise it's an invalid in-tree type or an out-of-tree platform is
loaded.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None
def get_env_variable_attn_backend() -> Optional[_Backend]:
"""
Get the backend override specified by the vLLM attention

View File

@@ -21,8 +21,8 @@ import torch
import zmq
from vllm import envs
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,