diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index fa95c3b2d39ea..20061ad2f8bf7 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
 DEVICE_REGULAR_ATTN_BACKENDS = {
     "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
-    "hip": ["ROCM_FLASH"],
+    "hip": ["ROCM_ATTN"],
     "cpu": ["TORCH_SDPA"],
 }
@@ -122,7 +122,7 @@ def test_env(
             backend = get_attn_backend(
                 16, torch.float16, None, block_size, use_mla=use_mla
             )
-            expected = "TRITON_ATTN"
+            expected = "ROCM_ATTN"
             assert backend.get_name() == expected

         elif device == "cuda":
diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py
index a59230528770c..9b7fb664956c6 100644
--- a/tests/kernels/attention/test_rocm_attention_selector.py
+++ b/tests/kernels/attention/test_rocm_attention_selector.py
@@ -18,7 +18,7 @@ def clear_cache():
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
+        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")

         # Set the current platform to ROCm using monkeypatch
         monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 188482e071ee6..7fee73da15a2a 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -14,7 +14,7 @@ from tests.v1.attention.utils import (
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ModelConfig
@@ -214,7 +214,7 @@ def run_attention_backend(
         actual_backend = _Backend.FLEX_ATTENTION
         use_direct_block_mask = False

-    builder_cls, impl_cls = get_attention_backend(actual_backend)
+    builder_cls, impl_cls = try_get_attention_backend(actual_backend)

     # Mock flashinfer's get_per_layer_parameters if needed
     if actual_backend == _Backend.FLASHINFER:
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index debaa6a5e0096..3b6a9115435c4 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm import _custom_ops as ops
 from vllm.attention.backends.registry import _Backend
@@ -239,7 +239,7 @@ def run_attention_backend(
 ) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
-    builder_cls, impl_cls = get_attention_backend(backend)
+    builder_cls, impl_cls = try_get_attention_backend(backend)

     # Build metadata
     builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
@@ -400,7 +400,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     # Determine if this is decode or prefill
     is_decode = []
     for i, backend in enumerate(BACKENDS_TO_TEST):
-        builder_cls, _ = get_attention_backend(backend)
+        builder_cls, _ = try_get_attention_backend(backend)
         is_decode.append(q_len <= builder_cls.reorder_batch_threshold)

     # Split q into nope and rope components
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index a22f32c9a31ca..819cd81be358d 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -8,7 +8,8 @@ from typing import Optional, Union
 import pytest
 import torch

-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.abstract import AttentionImpl
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
@@ -20,9 +21,11 @@ from vllm.config import (
     VllmConfig,
 )
 from vllm.config.model import ModelDType
-from vllm.platforms import current_platform
 from vllm.utils import resolve_obj_by_qualname
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder,
+    CommonAttentionMetadata,
+)
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -117,44 +120,17 @@ def create_common_attn_metadata(
     )


-def get_attention_backend(backend_name: _Backend):
-    """Set up attention backend classes for testing.
-
-    Args:
-        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
-        vllm_config: VllmConfig instance
-
-    Returns:
-        Tuple of (backend_builder_class, backend_impl_class)
-    """
-    backend_map = {
-        _Backend.FLASH_ATTN: (
-            "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
-            if current_platform.is_cuda()
-            else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
-        ),
-        _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
-        _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
-        _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
-        _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
-        _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
-        _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
-        _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
-        _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
-        _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
-        _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
-    }
-
-    if backend_name not in backend_map:
-        raise ValueError(f"Unknown backend: {backend_name}")
-
-    backend_class_name = backend_map[backend_name]
-
+def try_get_attention_backend(
+    backend: _Backend,
+) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
+    """Try to get the attention backend class, skipping test if not found."""
+    backend_class_str = backend_to_class_str(backend)
     try:
-        backend_class = resolve_obj_by_qualname(backend_class_name)
+        backend_class = resolve_obj_by_qualname(backend_class_str)
         return backend_class.get_builder_cls(), backend_class.get_impl_cls()
     except ImportError as e:
-        pytest.skip(f"{backend_name} not available: {e}")
+        pytest.skip(f"{backend_class_str} not available: {e}")
+    raise AssertionError("unreachable") from None


 def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
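For context, the test files below all resolve backends the same way through this helper; a minimal sketch of it in use (FLASH_ATTN is only an example value, any _Backend member works):

    from tests.v1.attention.utils import try_get_attention_backend
    from vllm.attention.backends.registry import _Backend

    # Resolves the backend class through the shared registry (BACKEND_MAP).
    # If the backend module cannot be imported on the current platform, the
    # calling test is skipped via pytest.skip() instead of erroring out.
    builder_cls, impl_cls = try_get_attention_backend(_Backend.FLASH_ATTN)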
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 4c490f2188aa2..0f0a3722ef2dd 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
     BatchSpec,
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import (
@@ -535,11 +535,11 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     sampling_metadata = mock.MagicMock()

     if attn_backend == "FLASH_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
     elif attn_backend == "TRITON_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
     elif attn_backend == "TREE_ATTN":
-        attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+        attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
@@ -674,7 +674,7 @@ def test_propose_tree(spec_token_tree):
     proposer.attn_layer_names = ["layer.0"]

     # Get the tree attention metadata builder.
-    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
index d7d9ef07e46ca..9ca7cf9e3e0e1 100644
--- a/tests/v1/spec_decode/test_mtp.py
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
     BatchSpec,
     create_common_attn_metadata,
     create_standard_kv_cache_spec,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import (
@@ -177,7 +177,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
     sampling_metadata = mock.MagicMock()

     # Setup attention metadata
-    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+    attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)

     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index a46e8e3ec7556..b31a2f27f54b0 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -9,7 +9,7 @@ import torch

 from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     create_vllm_config,
-    get_attention_backend,
+    try_get_attention_backend,
 )
 from vllm.attention.backends.registry import _Backend
 from vllm.config import ParallelConfig, SpeculativeConfig
@@ -63,7 +63,7 @@ def forward_attention(
     # Build common metadata.
     model_name = "meta-llama/Meta-Llama-3-8B"
-    builder_cls, impl_cls = get_attention_backend(backend)
+    builder_cls, impl_cls = try_get_attention_backend(backend)
     vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
     if spec_token_tree is not None:
         # Create speculative config if token tree is specified.
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 06f13044d5722..b74ae09e61126 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -3,13 +3,16 @@
 """Attention backend registry"""

 import enum

+from typing import Optional
+
+from vllm.utils import resolve_obj_by_qualname

 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     TRITON_ATTN = enum.auto()
     XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
+    ROCM_ATTN = enum.auto()
     ROCM_AITER_MLA = enum.auto()
     ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
@@ -24,5 +27,83 @@ class _Backend(enum.Enum):
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
-    ROCM_ATTN = enum.auto()
     ROCM_AITER_UNIFIED_ATTN = enum.auto()
+
+
+BACKEND_MAP = {
+    _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",  # noqa: E501
+    _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
+    _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
+    _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend",  # noqa: E501
+    _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend",  # noqa: E501
+    _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend",  # noqa: E501
+    _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend",  # noqa: E501
+    _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",  # noqa: E501
+    _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
+    _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
+    _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
+    _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",  # noqa: E501
+    _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
+    _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",  # noqa: E501
+    _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
+    _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",  # noqa: E501
+    _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend",  # noqa: E501
+}
+
+
+def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
+    """
+    Decorator to register a custom attention backend into BACKEND_MAP.
+    - If class_path is provided, use it.
+    - Otherwise, auto-generate it from the decorated class object.
+    Validation only checks that 'backend' is a valid _Backend enum member.
+    Overwriting existing mappings is allowed. This enables other hardware
+    platforms to plug in custom out-of-tree backends.
+    """
+    if not isinstance(backend, _Backend):
+        raise ValueError(f"{backend} is not a valid _Backend enum value.")
+
+    def decorator(cls):
+        path = class_path or f"{cls.__module__}.{cls.__qualname__}"
+        BACKEND_MAP[backend] = path
+        return cls
+
+    return decorator
+
+
+def backend_to_class_str(backend: _Backend) -> str:
+    """Get the backend class string
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class string
+    """
+    return BACKEND_MAP[backend]
+
+
+def backend_to_class(backend: _Backend) -> type:
+    """Get the backend class.
+
+    Args:
+        backend: The backend enum value
+
+    Returns:
+        The backend class
+    """
+    backend_class_name = backend_to_class_str(backend)
+    return resolve_obj_by_qualname(backend_class_name)
+
+
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
+
+    Returns:
+        _Backend: enum value if backend_name is a valid in-tree type
+        None: otherwise it's an invalid in-tree type or an out-of-tree platform
+        is loaded.
+    """
+    assert backend_name is not None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else None
- """ - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else None - - def get_env_variable_attn_backend() -> Optional[_Backend]: """ Get the backend override specified by the vLLM attention diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e3e3389fd1643..0d4744b9f4ab5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,8 +21,8 @@ import torch import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp,