[Platform] Refactor Platform attention backend selection to avoid breaking OOT platforms (#30212)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

parent 970713d4a4
commit ec154c36ee
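The core of the refactor: the per-call selection flags (head_size, dtype, kv_cache_dtype, block_size, use_mla, has_sink, use_sparse, use_mm_prefix, attn_type) are bundled into a single AttentionSelectorConfig NamedTuple, so adding a new flag later changes only the tuple, not the signature of Platform.get_attn_backend_cls, and out-of-tree (OOT) platform plugins keep working. A minimal sketch of an OOT override under the new contract; the class name and the returned backend path are hypothetical, not part of this diff:

    from typing import TYPE_CHECKING

    from vllm.platforms.interface import Platform

    if TYPE_CHECKING:
        from vllm.attention.backends.registry import AttentionBackendEnum
        from vllm.attention.selector import AttentionSelectorConfig


    class MyOOTPlatform(Platform):  # hypothetical out-of-tree platform
        @classmethod
        def get_attn_backend_cls(
            cls,
            selected_backend: "AttentionBackendEnum",
            attn_selector_config: "AttentionSelectorConfig",
        ) -> str:
            # All per-call flags arrive on the config object, so upstream can
            # add fields without breaking this override.
            if attn_selector_config.use_mla:
                raise NotImplementedError("MLA is not supported on this platform.")
            # Hypothetical dotted path to the plugin's attention backend class.
            return "my_oot_plugin.attention.MyOOTAttentionBackend"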
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -2,11 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from functools import cache
-from typing import cast, get_args
+from typing import NamedTuple, cast, get_args
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, AttentionType
 from vllm.attention.backends.registry import (
     MAMBA_TYPE_TO_BACKEND_MAP,
     MambaAttentionBackendEnum,
@@ -18,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
 logger = init_logger(__name__)
 
 
+class AttentionSelectorConfig(NamedTuple):
+    head_size: int
+    dtype: torch.dtype
+    kv_cache_dtype: CacheDType | None
+    block_size: int | None
+    use_mla: bool = False
+    has_sink: bool = False
+    use_sparse: bool = False
+    use_mm_prefix: bool = False
+    attn_type: str = AttentionType.DECODER
+
+    def __repr__(self):
+        return (
+            f"AttentionSelectorConfig(head_size={self.head_size}, "
+            f"dtype={self.dtype}, "
+            f"kv_cache_dtype={self.kv_cache_dtype}, "
+            f"block_size={self.block_size}, "
+            f"use_mla={self.use_mla}, "
+            f"has_sink={self.has_sink}, "
+            f"use_sparse={self.use_sparse}, "
+            f"use_mm_prefix={self.use_mm_prefix}, "
+            f"attn_type={self.attn_type})"
+        )
+
+
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -43,8 +68,7 @@ def get_attn_backend(
     vllm_config = get_current_vllm_config()
     backend_enum = vllm_config.attention_config.backend
 
-    return _cached_get_attn_backend(
-        backend=backend_enum,
+    attn_selector_config = AttentionSelectorConfig(
         head_size=head_size,
         dtype=dtype,
         kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@@ -53,36 +77,25 @@
         has_sink=has_sink,
         use_sparse=use_sparse,
         use_mm_prefix=use_mm_prefix,
-        attn_type=attn_type,
+        attn_type=attn_type or AttentionType.DECODER,
     )
 
+    return _cached_get_attn_backend(
+        backend=backend_enum,
+        attn_selector_config=attn_selector_config,
+    )
+
 
 @cache
 def _cached_get_attn_backend(
     backend,
-    head_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: CacheDType | None,
-    block_size: int | None,
-    use_mla: bool = False,
-    has_sink: bool = False,
-    use_sparse: bool = False,
-    use_mm_prefix: bool = False,
-    attn_type: str | None = None,
+    attn_selector_config: AttentionSelectorConfig,
 ) -> type[AttentionBackend]:
     from vllm.platforms import current_platform
 
     attention_cls = current_platform.get_attn_backend_cls(
         backend,
-        head_size,
-        dtype,
-        kv_cache_dtype,
-        block_size,
-        use_mla,
-        has_sink,
-        use_sparse,
-        use_mm_prefix,
-        attn_type,
+        attn_selector_config=attn_selector_config,
     )
     if not attention_cls:
         raise ValueError(
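AttentionSelectorConfig above is a plain typing.NamedTuple, and the platform hunks below lean on its standard helpers: _asdict() to keyword-expand every field (used for validate_configuration on CUDA) and _replace() to derive a copy with one field overridden (used to drop block_size before validation). A small sketch with hypothetical values, assuming a vLLM build that includes this change:

    import torch

    from vllm.attention.selector import AttentionSelectorConfig

    cfg = AttentionSelectorConfig(
        head_size=64,
        dtype=torch.float16,
        kv_cache_dtype=None,
        block_size=16,
    )

    # _replace returns a new tuple with the given fields overridden.
    cfg_no_block = cfg._replace(block_size=None)
    assert cfg.block_size == 16 and cfg_no_block.block_size is None

    # _asdict maps field names to values, so **cfg._asdict() forwards every
    # field as a keyword argument.
    assert cfg._asdict()["use_mla"] is False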
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
 logger = init_logger(__name__)
 
 if TYPE_CHECKING:
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import VllmConfig
 else:
     VllmConfig = None
@@ -126,21 +127,13 @@ class CpuPlatform(Platform):
     def get_attn_backend_cls(
         cls,
         selected_backend: "AttentionBackendEnum",
-        head_size: int,
-        dtype: torch.dtype,
-        kv_cache_dtype: str | None,
-        block_size: int,
-        use_mla: bool,
-        has_sink: bool,
-        use_sparse: bool,
-        use_mm_prefix: bool,
-        attn_type: str | None = None,
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> str:
         if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
-        if use_mla:
+        if attn_selector_config.use_mla:
             raise NotImplementedError("MLA is not supported on CPU.")
-        if use_sparse:
+        if attn_selector_config.use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on CPU.")
         return AttentionBackendEnum.CPU_ATTN.get_path()
 
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
 import vllm._C  # noqa
-from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.import_utils import import_pynvml
@@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
 from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import VllmConfig
     from vllm.config.cache import CacheDType
 else:
@@ -258,16 +258,8 @@ class CudaPlatformBase(Platform):
     @classmethod
     def get_valid_backends(
         cls,
-        head_size,
-        dtype,
-        kv_cache_dtype,
-        block_size,
-        use_mla,
-        has_sink,
-        use_sparse,
-        use_mm_prefix,
-        device_capability,
-        attn_type,
+        device_capability: DeviceCapability,
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> tuple[
         list[tuple["AttentionBackendEnum", int]],
         dict["AttentionBackendEnum", list[str]],
@@ -275,21 +267,15 @@ class CudaPlatformBase(Platform):
         valid_backends_priorities = []
         invalid_reasons = {}
 
-        backend_priorities = _get_backend_priorities(use_mla, device_capability)
+        backend_priorities = _get_backend_priorities(
+            attn_selector_config.use_mla, device_capability
+        )
         for priority, backend in enumerate(backend_priorities):
             try:
                 backend_class = backend.get_class()
                 invalid_reasons_i = backend_class.validate_configuration(
-                    head_size,
-                    dtype,
-                    kv_cache_dtype,
-                    block_size,
-                    use_mla,
-                    has_sink,
-                    use_sparse,
-                    use_mm_prefix,
-                    device_capability,
-                    attn_type,
+                    device_capability=device_capability,
+                    **attn_selector_config._asdict(),
                 )
             except ImportError:
                 invalid_reasons_i = ["ImportError"]
@@ -304,37 +290,19 @@ class CudaPlatformBase(Platform):
     def get_attn_backend_cls(
         cls,
         selected_backend: "AttentionBackendEnum",
-        head_size: int,
-        dtype: torch.dtype,
-        kv_cache_dtype: "CacheDType | None",
-        block_size: int | None,
-        use_mla: bool,
-        has_sink: bool,
-        use_sparse: bool,
-        use_mm_prefix: bool,
-        attn_type: str | None = None,
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> str:
-        if attn_type is None:
-            attn_type = AttentionType.DECODER
-
         device_capability = cls.get_device_capability()
         assert device_capability is not None
 
+        attn_selector_config = attn_selector_config._replace(block_size=None)
         # First try checking just the selected backend, if there is one.
         if selected_backend is not None:
             try:
                 backend_class = selected_backend.get_class()
                 invalid_reasons = backend_class.validate_configuration(
-                    head_size,
-                    dtype,
-                    kv_cache_dtype,
-                    None,
-                    use_mla,
-                    has_sink,
-                    use_sparse,
-                    use_mm_prefix,
-                    device_capability,
-                    attn_type,
+                    device_capability=device_capability,
+                    **attn_selector_config._asdict(),
                 )
             except ImportError:
                 invalid_reasons = ["ImportError"]
@@ -350,16 +318,8 @@ class CudaPlatformBase(Platform):
         # No selected backend or the selected backend is invalid,
         # so we try finding a valid backend.
         valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
-            head_size,
-            dtype,
-            kv_cache_dtype,
-            None,
-            use_mla,
-            has_sink,
-            use_sparse,
-            use_mm_prefix,
-            device_capability,
-            attn_type,
+            device_capability=device_capability,
+            attn_selector_config=attn_selector_config,
         )
         reasons_str = (
             "{"
@@ -369,11 +329,7 @@ class CudaPlatformBase(Platform):
             )
             + "}"
         )
-        config_str = (
-            f"head_size: {head_size}, dtype: {dtype}, "
-            f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
-            f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
-        )
+        config_str = attn_selector_config.__repr__()
         logger.debug_once(
             f"Some attention backends are not valid for {cls.device_name} with "
             f"{config_str}. Reasons: {reasons_str}."
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -18,8 +18,8 @@ from vllm.logger import init_logger
 if TYPE_CHECKING:
     from torch.distributed import PrefixStore, ProcessGroup
 
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import VllmConfig
-    from vllm.config.cache import CacheDType
     from vllm.inputs import ProcessorInputs, PromptType
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
@@ -226,15 +226,7 @@ class Platform:
     def get_attn_backend_cls(
         cls,
         selected_backend: "AttentionBackendEnum",
-        head_size: int,
-        dtype: torch.dtype,
-        kv_cache_dtype: "CacheDType | None",
-        block_size: int,
-        use_mla: bool,
-        has_sink: bool,
-        use_sparse: bool,
-        use_mm_prefix: bool,
-        attn_type: str | None = None,
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> str:
         """Get the attention backend class of a device."""
         return ""
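With the base hook above reduced to (selected_backend, attn_selector_config), the call made by _cached_get_attn_backend in the first file of this diff is the whole contract a platform override has to satisfy: accept an optional backend enum plus the config, and return a dotted backend path ("" makes the selector raise ValueError). A minimal sketch of that call with hypothetical values, assuming a vLLM build that includes this change:

    import torch

    from vllm.attention.selector import AttentionSelectorConfig
    from vllm.platforms import current_platform

    config = AttentionSelectorConfig(
        head_size=128,
        dtype=torch.bfloat16,
        kv_cache_dtype=None,
        block_size=16,
    )
    # selected_backend=None lets the platform fall back to its default choice.
    backend_path = current_platform.get_attn_backend_cls(
        None, attn_selector_config=config
    )
    print(backend_path)  # dotted backend path; "" would make the selector raise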
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
 from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -190,21 +191,16 @@ class RocmPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(
         cls,
-        selected_backend,
-        head_size,
-        dtype,
-        kv_cache_dtype,
-        block_size,
-        use_mla,
-        has_sink,
-        use_sparse,
-        use_mm_prefix,
-        attn_type: str | None = None,
+        selected_backend: "AttentionBackendEnum",
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> str:
         from vllm._aiter_ops import rocm_aiter_ops
 
-        if use_sparse:
-            if kv_cache_dtype.startswith("fp8"):
+        block_size = attn_selector_config.block_size
+        kv_cache_dtype = attn_selector_config.kv_cache_dtype
+
+        if attn_selector_config.use_sparse:
+            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
                 raise ValueError(
                     "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                 )
@@ -214,7 +210,7 @@ class RocmPlatform(Platform):
             logger.info_once("Using Sparse MLA backend on V1 engine.")
             return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
 
-        if use_mla:
+        if attn_selector_config.use_mla:
             if selected_backend is None:
                 selected_backend = (
                     AttentionBackendEnum.ROCM_AITER_MLA
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum
 if TYPE_CHECKING:
     from typing import TypeAlias
 
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import VllmConfig
     from vllm.config.cache import BlockSize
     from vllm.pooling_params import PoolingParams
@@ -57,17 +58,9 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(
         cls,
         selected_backend: "AttentionBackendEnum",
-        head_size: int,
-        dtype: torch.dtype,
-        kv_cache_dtype: str | None,
-        block_size: int,
-        use_mla: bool,
-        has_sink: bool,
-        use_sparse: bool,
-        use_mm_prefix: bool,
-        attn_type: str | None = None,
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> str:
-        if use_sparse:
+        if attn_selector_config.use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on TPU.")
         if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -14,6 +14,7 @@ from vllm.logger import init_logger
 from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import VllmConfig
 else:
     VllmConfig = None
@@ -42,15 +43,7 @@ class XPUPlatform(Platform):
     def get_attn_backend_cls(
         cls,
         selected_backend: "AttentionBackendEnum",
-        head_size: int,
-        dtype: torch.dtype,
-        kv_cache_dtype: str | None,
-        block_size: int,
-        use_mla: bool,
-        has_sink: bool,
-        use_sparse: bool,
-        use_mm_prefix: bool,
-        attn_type: str | None = None,
+        attn_selector_config: "AttentionSelectorConfig",
     ) -> str:
         from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
@@ -60,7 +53,7 @@ class XPUPlatform(Platform):
            "only NHD layout is supported by XPU attention kernels."
        )
 
-        if use_sparse:
+        if attn_selector_config.use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on XPU.")
         if selected_backend == AttentionBackendEnum.TRITON_ATTN:
             logger.info_once("Using Triton backend.")
@@ -71,7 +64,7 @@ class XPUPlatform(Platform):
         elif selected_backend:
             raise ValueError(
                 f"Invalid attention backend for {cls.device_name}, "
-                f"with use_mla: {use_mla}"
+                f"with use_mla: {attn_selector_config.use_mla}"
             )
 
         logger.info("Using Flash Attention backend.")