From ec154c36ee74f35def28e4ddc1c16a0dc7a8c112 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 16 Dec 2025 01:36:07 +0800 Subject: [PATCH] [Platform] Refactor Platform attention backend selection to avoid breakpoint for OOT platform (#30212) Signed-off-by: Isotr0py Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm/attention/selector.py | 59 +++++++++++++++++------------ vllm/platforms/cpu.py | 15 ++------ vllm/platforms/cuda.py | 74 ++++++++----------------------------- vllm/platforms/interface.py | 12 +----- vllm/platforms/rocm.py | 22 +++++------ vllm/platforms/tpu.py | 13 ++----- vllm/platforms/xpu.py | 15 ++------ 7 files changed, 73 insertions(+), 137 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index bbf95ff009001..e66f698add99d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from functools import cache -from typing import cast, get_args +from typing import NamedTuple, cast, get_args import torch -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.registry import ( MAMBA_TYPE_TO_BACKEND_MAP, MambaAttentionBackendEnum, @@ -18,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) +class AttentionSelectorConfig(NamedTuple): + head_size: int + dtype: torch.dtype + kv_cache_dtype: CacheDType | None + block_size: int | None + use_mla: bool = False + has_sink: bool = False + use_sparse: bool = False + use_mm_prefix: bool = False + attn_type: str = AttentionType.DECODER + + def __repr__(self): + return ( + f"AttentionSelectorConfig(head_size={self.head_size}, " + f"dtype={self.dtype}, " + f"kv_cache_dtype={self.kv_cache_dtype}, " + f"block_size={self.block_size}, " + f"use_mla={self.use_mla}, " + f"has_sink={self.has_sink}, " + f"use_sparse={self.use_sparse}, " + f"use_mm_prefix={self.use_mm_prefix}, " + f"attn_type={self.attn_type})" + ) + + def get_attn_backend( head_size: int, dtype: torch.dtype, @@ -43,8 +68,7 @@ def get_attn_backend( vllm_config = get_current_vllm_config() backend_enum = vllm_config.attention_config.backend - return _cached_get_attn_backend( - backend=backend_enum, + attn_selector_config = AttentionSelectorConfig( head_size=head_size, dtype=dtype, kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), @@ -53,36 +77,25 @@ def get_attn_backend( has_sink=has_sink, use_sparse=use_sparse, use_mm_prefix=use_mm_prefix, - attn_type=attn_type, + attn_type=attn_type or AttentionType.DECODER, + ) + + return _cached_get_attn_backend( + backend=backend_enum, + attn_selector_config=attn_selector_config, ) @cache def _cached_get_attn_backend( backend, - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: CacheDType | None, - block_size: int | None, - use_mla: bool = False, - has_sink: bool = False, - use_sparse: bool = False, - use_mm_prefix: bool = False, - attn_type: str | None = None, + attn_selector_config: AttentionSelectorConfig, ) -> type[AttentionBackend]: from vllm.platforms import current_platform attention_cls = current_platform.get_attn_backend_cls( backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - attn_type, + attn_selector_config=attn_selector_config, ) if not attention_cls: raise ValueError( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index d961dcf13e53e..e1b461d79a655 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum logger = init_logger(__name__) if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig else: VllmConfig = None @@ -126,21 +127,13 @@ class CpuPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: logger.info("Cannot use %s backend on CPU.", selected_backend) - if use_mla: + if attn_selector_config.use_mla: raise NotImplementedError("MLA is not supported on CPU.") - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on CPU.") return AttentionBackendEnum.CPU_ATTN.get_path() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ad5a6789b2023..2dc4ba5d70cac 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,7 +14,6 @@ from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa -from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml @@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import CacheDType else: @@ -258,16 +258,8 @@ class CudaPlatformBase(Platform): @classmethod def get_valid_backends( cls, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability: DeviceCapability, + attn_selector_config: "AttentionSelectorConfig", ) -> tuple[ list[tuple["AttentionBackendEnum", int]], dict["AttentionBackendEnum", list[str]], @@ -275,21 +267,15 @@ class CudaPlatformBase(Platform): valid_backends_priorities = [] invalid_reasons = {} - backend_priorities = _get_backend_priorities(use_mla, device_capability) + backend_priorities = _get_backend_priorities( + attn_selector_config.use_mla, device_capability + ) for priority, backend in enumerate(backend_priorities): try: backend_class = backend.get_class() invalid_reasons_i = backend_class.validate_configuration( - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability=device_capability, + **attn_selector_config._asdict(), ) except ImportError: invalid_reasons_i = ["ImportError"] @@ -304,37 +290,19 @@ class CudaPlatformBase(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: "CacheDType | None", - block_size: int | None, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: - if attn_type is None: - attn_type = AttentionType.DECODER - device_capability = cls.get_device_capability() assert device_capability is not None + attn_selector_config = attn_selector_config._replace(block_size=None) # First try checking just the selected backend, if there is one. if selected_backend is not None: try: backend_class = selected_backend.get_class() invalid_reasons = backend_class.validate_configuration( - head_size, - dtype, - kv_cache_dtype, - None, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability=device_capability, + **attn_selector_config._asdict(), ) except ImportError: invalid_reasons = ["ImportError"] @@ -350,16 +318,8 @@ class CudaPlatformBase(Platform): # No selected backend or the selected backend is invalid, # so we try finding a valid backend. valid_backends_priorities, invalid_reasons = cls.get_valid_backends( - head_size, - dtype, - kv_cache_dtype, - None, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - device_capability, - attn_type, + device_capability=device_capability, + attn_selector_config=attn_selector_config, ) reasons_str = ( "{" @@ -369,11 +329,7 @@ class CudaPlatformBase(Platform): ) + "}" ) - config_str = ( - f"head_size: {head_size}, dtype: {dtype}, " - f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " - f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" - ) + config_str = attn_selector_config.__repr__() logger.debug_once( f"Some attention backends are not valid for {cls.device_name} with " f"{config_str}. Reasons: {reasons_str}." diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 9788e5b564165..d4b40045df384 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -18,8 +18,8 @@ from vllm.logger import init_logger if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig - from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -226,15 +226,7 @@ class Platform: def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: "CacheDType | None", - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b90fb3686c280..e469a928da229 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig logger = init_logger(__name__) @@ -190,21 +191,16 @@ class RocmPlatform(Platform): @classmethod def get_attn_backend_cls( cls, - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - use_mm_prefix, - attn_type: str | None = None, + selected_backend: "AttentionBackendEnum", + attn_selector_config: "AttentionSelectorConfig", ) -> str: from vllm._aiter_ops import rocm_aiter_ops - if use_sparse: - if kv_cache_dtype.startswith("fp8"): + block_size = attn_selector_config.block_size + kv_cache_dtype = attn_selector_config.kv_cache_dtype + + if attn_selector_config.use_sparse: + if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) @@ -214,7 +210,7 @@ class RocmPlatform(Platform): logger.info_once("Using Sparse MLA backend on V1 engine.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() - if use_mla: + if attn_selector_config.use_mla: if selected_backend is None: selected_backend = ( AttentionBackendEnum.ROCM_AITER_MLA diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 50de87098f05c..7c479bf2b6a0e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: from typing import TypeAlias + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams @@ -57,17 +58,9 @@ class TpuPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c1ec2d41c73b0..af8979af36643 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -14,6 +14,7 @@ from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig else: VllmConfig = None @@ -42,15 +43,7 @@ class XPUPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - use_mm_prefix: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: from vllm.v1.attention.backends.utils import set_kv_cache_layout @@ -60,7 +53,7 @@ class XPUPlatform(Platform): "only NHD layout is supported by XPU attention kernels." ) - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend.") @@ -71,7 +64,7 @@ class XPUPlatform(Platform): elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " - f"with use_mla: {use_mla}" + f"with use_mla: {attn_selector_config.use_mla}" ) logger.info("Using Flash Attention backend.")