[Platform] Refactor Platform attention backend selection to avoid breakpoint for OOT platform (#30212)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Isotr0py 2025-12-16 01:36:07 +08:00 committed by GitHub
parent 970713d4a4
commit ec154c36ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 73 additions and 137 deletions

View File

@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache from functools import cache
from typing import cast, get_args from typing import NamedTuple, cast, get_args
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import ( from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP, MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum, MambaAttentionBackendEnum,
@ -18,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
class AttentionSelectorConfig(NamedTuple):
    """Immutable bundle of the parameters used to pick an attention backend.

    Grouping these values in a single tuple lets them be passed around as
    one ``attn_selector_config`` argument instead of a long positional list.
    Being a ``NamedTuple``, instances are hashable whenever every field
    value is hashable.
    """

    # Required fields (no defaults).
    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    # Optional feature flags; all default to disabled.
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    # Attention type; defaults to AttentionType.DECODER.
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        """Return ``AttentionSelectorConfig(field=value, ...)`` in field order.

        Values are rendered with plain f-string formatting rather than
        ``repr()``, so string-valued fields appear without quotes. This is
        what distinguishes it from the default ``NamedTuple`` repr.
        """
        rendered_fields = ", ".join(
            f"{name}={value}" for name, value in zip(self._fields, self)
        )
        return f"AttentionSelectorConfig({rendered_fields})"
def get_attn_backend( def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
@ -43,8 +68,7 @@ def get_attn_backend(
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend( attn_selector_config = AttentionSelectorConfig(
backend=backend_enum,
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@ -53,36 +77,25 @@ def get_attn_backend(
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse, use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix, use_mm_prefix=use_mm_prefix,
attn_type=attn_type, attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
) )
@cache @cache
def _cached_get_attn_backend( def _cached_get_attn_backend(
backend, backend,
head_size: int, attn_selector_config: AttentionSelectorConfig,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
backend, backend,
head_size, attn_selector_config=attn_selector_config,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
) )
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(

View File

@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
VllmConfig = None VllmConfig = None
@ -126,21 +127,13 @@ class CpuPlatform(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if attn_selector_config.use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.") raise NotImplementedError("Sparse Attention is not supported on CPU.")
return AttentionBackendEnum.CPU_ATTN.get_path() return AttentionBackendEnum.CPU_ATTN.get_path()

View File

@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
import vllm._C # noqa import vllm._C # noqa
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml from vllm.utils.import_utils import import_pynvml
@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
else: else:
@ -258,16 +258,8 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_valid_backends( def get_valid_backends(
cls, cls,
head_size, device_capability: DeviceCapability,
dtype, attn_selector_config: "AttentionSelectorConfig",
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]], dict["AttentionBackendEnum", list[str]],
@ -275,21 +267,15 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = [] valid_backends_priorities = []
invalid_reasons = {} invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability) backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
for priority, backend in enumerate(backend_priorities): for priority, backend in enumerate(backend_priorities):
try: try:
backend_class = backend.get_class() backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration( invalid_reasons_i = backend_class.validate_configuration(
head_size, device_capability=device_capability,
dtype, **attn_selector_config._asdict(),
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
@ -304,37 +290,19 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
assert device_capability is not None assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one. # First try checking just the selected backend, if there is one.
if selected_backend is not None: if selected_backend is not None:
try: try:
backend_class = selected_backend.get_class() backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration( invalid_reasons = backend_class.validate_configuration(
head_size, device_capability=device_capability,
dtype, **attn_selector_config._asdict(),
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons = ["ImportError"] invalid_reasons = ["ImportError"]
@ -350,16 +318,8 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid, # No selected backend or the selected backend is invalid,
# so we try finding a valid backend. # so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends( valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size, device_capability=device_capability,
dtype, attn_selector_config=attn_selector_config,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) )
reasons_str = ( reasons_str = (
"{" "{"
@ -369,11 +329,7 @@ class CudaPlatformBase(Platform):
) )
+ "}" + "}"
) )
config_str = ( config_str = attn_selector_config.__repr__()
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
logger.debug_once( logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with " f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}." f"{config_str}. Reasons: {reasons_str}."

View File

@ -18,8 +18,8 @@ from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@ -226,15 +226,7 @@ class Platform:
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""

View File

@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -190,21 +191,16 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend, selected_backend: "AttentionBackendEnum",
head_size, attn_selector_config: "AttentionSelectorConfig",
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
if use_sparse: block_size = attn_selector_config.block_size
if kv_cache_dtype.startswith("fp8"): kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError( raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
) )
@ -214,7 +210,7 @@ class RocmPlatform(Platform):
logger.info_once("Using Sparse MLA backend on V1 engine.") logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla: if attn_selector_config.use_mla:
if selected_backend is None: if selected_backend is None:
selected_backend = ( selected_backend = (
AttentionBackendEnum.ROCM_AITER_MLA AttentionBackendEnum.ROCM_AITER_MLA

View File

@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeAlias from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import BlockSize from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
@ -57,17 +58,9 @@ class TpuPlatform(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
if use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.") raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS: if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)

View File

@ -14,6 +14,7 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
VllmConfig = None VllmConfig = None
@ -42,15 +43,7 @@ class XPUPlatform(Platform):
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum",
head_size: int, attn_selector_config: "AttentionSelectorConfig",
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.attention.backends.utils import set_kv_cache_layout
@ -60,7 +53,7 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels." "only NHD layout is supported by XPU attention kernels."
) )
if use_sparse: if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.") raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN: if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.") logger.info_once("Using Triton backend.")
@ -71,7 +64,7 @@ class XPUPlatform(Platform):
elif selected_backend: elif selected_backend:
raise ValueError( raise ValueError(
f"Invalid attention backend for {cls.device_name}, " f"Invalid attention backend for {cls.device_name}, "
f"with use_mla: {use_mla}" f"with use_mla: {attn_selector_config.use_mla}"
) )
logger.info("Using Flash Attention backend.") logger.info("Using Flash Attention backend.")