# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""

from collections.abc import Callable
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, cast

from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)


class _AttentionBackendEnumMeta(EnumMeta):
    """Metaclass for AttentionBackendEnum to provide better error messages."""

    def __getitem__(cls, name: str):
        """Get backend by name with helpful error messages.

        Raises:
            ValueError: If ``name`` is not a member name. The message lists
                every valid member name. (A plain Enum would raise KeyError.)
        """
        try:
            return super().__getitem__(name)
        except KeyError:
            members = cast("dict[str, Enum]", cls.__members__).keys()
            valid_backends = ", ".join(members)
            raise ValueError(
                f"Unknown attention backend: '{name}'. "
                f"Valid options are: {valid_backends}"
            ) from None


class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
    """Enumeration of all supported attention backends.

    The enum value is the default class path, but this can be overridden
    at runtime using register_backend().

    To get the actual backend class (respecting overrides), use:
        backend.get_class()
    """

    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
    ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
    ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
    ROCM_AITER_TRITON_MLA = (
        "vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
    )
    ROCM_AITER_FA = (
        "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
    )
    ROCM_AITER_MLA_SPARSE = (
        "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
    )
    TORCH_SDPA = ""  # this tag is only used for ViT
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    FLASHINFER_MLA = (
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
    )
    TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
    FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
    FLASHMLA_SPARSE = (
        "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
    )
    FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
    PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
    IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
    NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
    FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
    TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
    ROCM_AITER_UNIFIED_ATTN = (
        "vllm.v1.attention.backends.rocm_aiter_unified_attn."
        "RocmAiterUnifiedAttentionBackend"
    )
    CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
    # Placeholder for third-party/custom backends - must be registered before use.
    # The value is None rather than "" on purpose: Enum turns members that share
    # a value into aliases, so "" would make CUSTOM an alias of TORCH_SDPA
    # (same object, name "TORCH_SDPA", shared override slot).
    CUSTOM = None

    def get_path(self, include_classname: bool = True) -> str:
        """Get the class path for this backend (respects overrides).

        Args:
            include_classname: If False, strip the final dotted component
                (the class name) and return only the module path.

        Returns:
            The fully qualified class path string

        Raises:
            ValueError: If the member has no default path (e.g. Backend.CUSTOM)
                and no override has been registered for it
        """
        path = _ATTN_OVERRIDES.get(self, self.value)
        if not path:
            raise ValueError(
                f"Backend {self.name} must be registered before use. "
                f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
            )
        if not include_classname:
            path = path.rsplit(".", 1)[0]
        return path

    def get_class(self) -> "type[AttentionBackend]":
        """Get the backend class (respects overrides).

        Returns:
            The backend class

        Raises:
            ImportError: If the backend class cannot be imported
            ValueError: If Backend.CUSTOM is used without being registered
        """
        return resolve_obj_by_qualname(self.get_path())

    def is_overridden(self) -> bool:
        """Check if this backend has been overridden.

        Returns:
            True if the backend has a registered override
        """
        return self in _ATTN_OVERRIDES

    def clear_override(self) -> None:
        """Clear any override for this backend, reverting to the default."""
        _ATTN_OVERRIDES.pop(self, None)


class MambaAttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
    """Enumeration of all supported mamba attention backends.

    The enum value is the default class path, but this can be overridden
    at runtime using register_backend().

    To get the actual backend class (respecting overrides), use:
        backend.get_class()
    """

    MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
    MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
    SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
    LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
    GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
    # Placeholder for third-party/custom backends - must be registered before use.
    # Keep this the only member with an empty value: members that share a value
    # become Enum aliases of each other.
    CUSTOM = ""

    def get_path(self, include_classname: bool = True) -> str:
        """Get the class path for this backend (respects overrides).

        Args:
            include_classname: If False, strip the final dotted component
                (the class name) and return only the module path.

        Returns:
            The fully qualified class path string

        Raises:
            ValueError: If Backend.CUSTOM is used without being registered
        """
        path = _MAMBA_ATTN_OVERRIDES.get(self, self.value)
        if not path:
            raise ValueError(
                f"Backend {self.name} must be registered before use. "
                f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
            )
        if not include_classname:
            path = path.rsplit(".", 1)[0]
        return path

    def get_class(self) -> "type[AttentionBackend]":
        """Get the backend class (respects overrides).

        Returns:
            The backend class

        Raises:
            ImportError: If the backend class cannot be imported
            ValueError: If Backend.CUSTOM is used without being registered
        """
        return resolve_obj_by_qualname(self.get_path())

    def is_overridden(self) -> bool:
        """Check if this backend has been overridden.

        Returns:
            True if the backend has a registered override
        """
        return self in _MAMBA_ATTN_OVERRIDES

    def clear_override(self) -> None:
        """Clear any override for this backend, reverting to the default."""
        _MAMBA_ATTN_OVERRIDES.pop(self, None)


# Maps a mamba type string to the *name* of the matching
# MambaAttentionBackendEnum member.
MAMBA_TYPE_TO_BACKEND_MAP = {
    "mamba1": MambaAttentionBackendEnum.MAMBA1.name,
    "mamba2": MambaAttentionBackendEnum.MAMBA2.name,
    "short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
    "linear_attention": MambaAttentionBackendEnum.LINEAR.name,
    "gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
    "custom": MambaAttentionBackendEnum.CUSTOM.name,
}


# Runtime overrides: enum member -> fully qualified class path. Consulted by
# the get_path() method of the corresponding enum before its default value.
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}

_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}


def register_backend(
    backend: AttentionBackendEnum | MambaAttentionBackendEnum,
    is_mamba: bool = False,
    class_path: str | None = None,
) -> Callable[[type], type]:
    """Register or override a backend implementation.

    Args:
        backend: The AttentionBackendEnum or MambaAttentionBackendEnum member
            to register
        is_mamba: Store the override in the mamba registry. This is implied
            when ``backend`` is a MambaAttentionBackendEnum member, so it
            only needs to be passed for backwards compatibility. A string
            given in this position is treated as ``class_path``.
        class_path: Optional class path. If not provided and used as
            decorator, will be auto-generated from the class.

    Returns:
        Decorator function if class_path is None, otherwise a no-op

    Raises:
        TypeError: If a class path is supplied both positionally (in the
            ``is_mamba`` slot) and as ``class_path``

    Examples:
        # Override an existing attention backend
        @register_backend(AttentionBackendEnum.FLASH_ATTN)
        class MyCustomFlashAttn:
            ...

        # Override an existing mamba attention backend
        @register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
        class MyCustomMambaAttn:
            ...

        # Register a custom third-party attention backend
        @register_backend(AttentionBackendEnum.CUSTOM)
        class MyCustomBackend:
            ...

        # Direct registration
        register_backend(
            AttentionBackendEnum.CUSTOM,
            class_path="my.module.MyCustomBackend",
        )
    """
    # The error message raised by get_path() (and older examples) advise
    # `register_backend(backend, "pkg.Class")`. With this signature that binds
    # the path to `is_mamba`, which used to register nothing at all. Accept
    # that call form instead of silently dropping the path.
    if isinstance(is_mamba, str):
        if isinstance(class_path, str):
            raise TypeError(
                "register_backend() received two class paths; pass the path "
                "once, preferably as class_path=..."
            )
        # Also tolerates (backend, "pkg.Class", True): a bool that landed in
        # the class_path slot is taken as the mamba flag.
        class_path, is_mamba = is_mamba, bool(class_path)

    # A mamba member always belongs in the mamba registry: its get_path()
    # never reads _ATTN_OVERRIDES, so storing it there would have no effect.
    use_mamba = is_mamba or isinstance(backend, MambaAttentionBackendEnum)
    registry = cast(
        "dict[Enum, str]",
        _MAMBA_ATTN_OVERRIDES if use_mamba else _ATTN_OVERRIDES,
    )

    if class_path is not None:
        registry[backend] = class_path
        # Direct registration is done; still return a callable so the
        # declared return type holds.
        return lambda x: x

    def decorator(cls: type) -> type:
        registry[backend] = f"{cls.__module__}.{cls.__qualname__}"
        return cls

    return decorator


# Logged whenever the deprecated _Backend alias is used.
_BACKEND_DEPRECATION_MESSAGE = (
    "_Backend has been renamed to AttentionBackendEnum. "
    "Please update your code to use AttentionBackendEnum instead. "
    "_Backend will be removed in a future release."
)


# Backwards compatibility alias for plugins
class _BackendMeta(type):
    """Metaclass to provide deprecation warnings when accessing _Backend."""

    def __getattribute__(cls, name: str):
        """Forward attribute access to AttentionBackendEnum.

        A deprecation warning is logged first, except for the introspection
        attributes ``__class__``, ``__mro__`` and ``__name__``.
        """
        if name not in ("__class__", "__mro__", "__name__"):
            logger.warning(_BACKEND_DEPRECATION_MESSAGE)
        return getattr(AttentionBackendEnum, name)

    def __getitem__(cls, name: str):
        """Forward ``_Backend[name]`` to AttentionBackendEnum with a warning."""
        logger.warning(_BACKEND_DEPRECATION_MESSAGE)
        return AttentionBackendEnum[name]


class _Backend(metaclass=_BackendMeta):
    """Deprecated: Use AttentionBackendEnum instead.

    This class is provided for backwards compatibility with plugins
    and will be removed in a future release.
    """

    pass