[Platform] Refactor Platform attention backend selection to avoid breaking OOT platforms (#30212)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Isotr0py 2025-12-16 01:36:07 +08:00 committed by GitHub
parent 970713d4a4
commit ec154c36ee
7 changed files with 73 additions and 137 deletions

View File

@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import cast, get_args
from typing import NamedTuple, cast, get_args
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
@@ -18,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class AttentionSelectorConfig(NamedTuple):
head_size: int
dtype: torch.dtype
kv_cache_dtype: CacheDType | None
block_size: int | None
use_mla: bool = False
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
attn_type: str = AttentionType.DECODER
def __repr__(self):
return (
f"AttentionSelectorConfig(head_size={self.head_size}, "
f"dtype={self.dtype}, "
f"kv_cache_dtype={self.kv_cache_dtype}, "
f"block_size={self.block_size}, "
f"use_mla={self.use_mla}, "
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"attn_type={self.attn_type})"
)
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@@ -43,8 +68,7 @@ def get_attn_backend(
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@@ -53,36 +77,25 @@
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
)
@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
attn_selector_config=attn_selector_config,
)
if not attention_cls:
raise ValueError(
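
The key mechanical change in the selector is that `_cached_get_attn_backend` now takes a single hashable `AttentionSelectorConfig` NamedTuple instead of ten positional arguments, so `functools.cache` keys on one immutable object and each platform receives the whole configuration at once. A minimal standalone sketch of that caching pattern (the class and function names here are illustrative stand-ins, not vLLM code):

```python
# Minimal standalone sketch (not vLLM code): a hashable NamedTuple lets
# functools.cache key the backend lookup on a single config object
# instead of a long list of positional arguments.
from functools import cache
from typing import NamedTuple

import torch


class SelectorConfigSketch(NamedTuple):
    # Hypothetical stand-in for AttentionSelectorConfig.
    head_size: int
    dtype: torch.dtype
    use_mla: bool = False


@cache
def pick_backend(config: SelectorConfigSketch) -> str:
    # NamedTuples hash by field values, so repeated calls with an equal
    # config hit the cache instead of re-resolving the backend.
    return "mla-backend" if config.use_mla else "default-backend"


cfg = SelectorConfigSketch(head_size=128, dtype=torch.float16)
assert pick_backend(cfg) is pick_backend(SelectorConfigSketch(128, torch.float16))
```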

View File

@@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
@@ -126,21 +127,13 @@ class CpuPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
if attn_selector_config.use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.")
return AttentionBackendEnum.CPU_ATTN.get_path()
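
Because every platform hook now accepts the keyword-only `attn_selector_config`, an out-of-tree (OOT) platform no longer breaks whenever a new selector parameter is added upstream; it only reads the fields it cares about. A hedged sketch of what an OOT override might look like under the new interface (the class name, device name, `PlatformEnum.OOT` usage, and returned backend path are assumptions for illustration):

```python
# Hypothetical out-of-tree platform; the class name, device name, and
# returned backend path are invented for illustration only.
from typing import TYPE_CHECKING

from vllm.platforms.interface import Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.selector import AttentionSelectorConfig


class MyAcceleratorPlatform(Platform):
    _enum = PlatformEnum.OOT  # assumed OOT plugin enum member
    device_name = "my_accelerator"

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        # Feature checks read fields off the config object, so selector
        # fields added upstream do not change this signature.
        if attn_selector_config.use_mla:
            raise NotImplementedError("MLA is not supported on this device.")
        # Fully qualified path of the plugin's attention backend (made up).
        return "my_plugin.attention.MyAttentionBackend"
```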

View File

@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
@@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
@@ -258,16 +258,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_valid_backends(
cls,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
@@ -275,21 +267,15 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability)
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons_i = ["ImportError"]
@@ -304,37 +290,19 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability()
assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
@@ -350,16 +318,8 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
attn_selector_config=attn_selector_config,
)
reasons_str = (
"{"
@@ -369,11 +329,7 @@ class CudaPlatformBase(Platform):
)
+ "}"
)
config_str = (
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."

View File

@@ -18,8 +18,8 @@ from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@@ -226,15 +226,7 @@ class Platform:
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
"""Get the attention backend class of a device."""
return ""

View File

@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
logger = init_logger(__name__)
@@ -190,21 +191,16 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
if use_sparse:
if kv_cache_dtype.startswith("fp8"):
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
@@ -214,7 +210,7 @@
logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla:
if attn_selector_config.use_mla:
if selected_backend is None:
selected_backend = (
AttentionBackendEnum.ROCM_AITER_MLA

View File

@@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
@@ -57,17 +58,9 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)

View File

@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
@@ -42,15 +43,7 @@ class XPUPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
@@ -60,7 +53,7 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
@@ -71,7 +64,7 @@ class XPUPlatform(Platform):
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_mla: {use_mla}"
f"with use_mla: {attn_selector_config.use_mla}"
)
logger.info("Using Flash Attention backend.")