[Misc] Log the reason for falling back to FlexAttention (#20699)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-07-14 19:16:51 +08:00 (committed by GitHub)
commit e8cc53af5e
parent a4851cfe68
10 changed files with 105 additions and 33 deletions


@@ -3,6 +3,7 @@
 import os
 from contextlib import contextmanager
+from dataclasses import dataclass
 from functools import cache
 from typing import Generator, Optional, Union
@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
-def supports_head_size(
+@dataclass(frozen=True)
+class _IsSupported:
+    can_import: bool
+    head_size: bool
+    dtype: bool
+
+    def __bool__(self) -> bool:
+        return self.can_import and self.head_size and self.dtype
+
+
+def is_attn_backend_supported(
     attn_backend: Union[str, type[AttentionBackend]],
     head_size: int,
-) -> bool:
+    dtype: torch.dtype,
+    *,
+    allow_import_error: bool = True,
+) -> _IsSupported:
     if isinstance(attn_backend, str):
         try:
             attn_backend = resolve_obj_by_qualname(attn_backend)
         except ImportError:
-            return False
+            if not allow_import_error:
+                raise
+
+            return _IsSupported(can_import=False, head_size=False, dtype=False)
 
     assert isinstance(attn_backend, type)
 
     # TODO: Update the interface once V0 is removed
     if get_supported_head_sizes := getattr(attn_backend,
                                            "get_supported_head_sizes", None):
-        return head_size in get_supported_head_sizes()
-    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        is_head_size_supported = head_size in get_supported_head_sizes()
+    elif validate_head_size := getattr(attn_backend, "validate_head_size",
+                                       None):
         try:
             validate_head_size(head_size)
-            return True
+            is_head_size_supported = True
         except Exception:
-            return False
+            is_head_size_supported = False
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "head size validation")
 
-    raise NotImplementedError(f"{attn_backend.__name__} does not support "
-                              "head size validation")
+    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
+                                       None):
+        is_dtype_supported = dtype in get_supported_dtypes()
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "dtype validation")
+
+    return _IsSupported(
+        can_import=True,
+        head_size=is_head_size_supported,
+        dtype=is_dtype_supported,
+    )
 
 
 def get_attn_backend(
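
Note (not part of the diff): a minimal sketch of how the new helper can be queried, assuming vLLM is importable. The backend path and the rejected values below are illustrative only.

    import torch

    from vllm.attention.selector import is_attn_backend_supported

    # Illustrative qualified name; a backend class or its dotted path can be passed.
    support = is_attn_backend_supported(
        "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
        head_size=80,         # not in FlashAttention's supported head sizes
        dtype=torch.float32,  # FlashAttention only lists fp16/bf16
    )

    if not support:
        # _IsSupported.__bool__ AND-combines the three flags, so the individual
        # fields explain why the backend was rejected.
        print(support.can_import, support.head_size, support.dtype)

Because the result is a dataclass rather than a plain bool, the caller can log exactly which property ruled the backend out.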


@@ -259,43 +259,56 @@ class CudaPlatformBase(Platform):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return FLASH_ATTN_V1
 
-            from vllm.attention.selector import supports_head_size
+            from vllm.attention.selector import is_attn_backend_supported
 
             # Default backends for V1 engine
-            # FP32 is only supported by FlexAttention
-            if dtype not in (torch.float16, torch.bfloat16):
-                logger.info_once(
-                    "Using FlexAttention backend for %s on V1 engine.",
-                    dtype,
-                )
-                return FLEX_ATTENTION_V1
-
             # Prefer FlashInfer for Blackwell GPUs if installed
-            if cls.is_device_capability(100) and \
-                supports_head_size(FLASHINFER_V1, head_size):
-                try:
-                    import flashinfer  # noqa: F401
+            if cls.is_device_capability(100):
+                if is_default_backend_supported := is_attn_backend_supported(
+                        FLASHINFER_V1, head_size, dtype):
                     from vllm.v1.attention.backends.utils import (
                         set_kv_cache_layout)
                     logger.info_once(
                         "Using FlashInfer backend with HND KV cache layout on "
                         "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                     set_kv_cache_layout("HND")
                     return FLASHINFER_V1
-                except ImportError:
-                    logger.info_once(
+
+                if not is_default_backend_supported.can_import:
+                    logger.warning_once(
                         "FlashInfer failed to import for V1 engine on "
                         "Blackwell (SM 10.0) GPUs; it is recommended to "
                         "install FlashInfer for better performance.")
-                    pass
 
-            # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80) and \
-                supports_head_size(FLASH_ATTN_V1, head_size):
-                logger.info_once("Using Flash Attention backend on V1 engine.")
-                return FLASH_ATTN_V1
-
-            logger.info_once("Using FlexAttention backend on V1 engine.")
+            # FlashAttention is the default for SM 8.0+ GPUs
+            if cls.has_device_capability(80):
+                if is_default_backend_supported := is_attn_backend_supported(
+                        FLASH_ATTN_V1, head_size, dtype,
+                        allow_import_error=False):
+                    logger.info_once("Using Flash Attention backend on "
+                                     "V1 engine.")
+                    return FLASH_ATTN_V1
+
+            # FlexAttention is the default for older GPUs
+            else:
+                logger.info_once("Using FlexAttention backend on V1 engine.")
+                return FLEX_ATTENTION_V1
+
+            assert not is_default_backend_supported
+
+            use_flex_attention_reason = {}
+            if not is_default_backend_supported.head_size:
+                use_flex_attention_reason["head_size"] = head_size
+            if not is_default_backend_supported.dtype:
+                use_flex_attention_reason["dtype"] = dtype
+
+            logger.info_once(
+                "Using FlexAttention backend for %s on V1 engine.",
+                ", ".join(f"{k}={v}"
+                          for k, v in use_flex_attention_reason.items()),
+            )
             return FLEX_ATTENTION_V1
 
         # Backends for V0 engine
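
Note (not part of the diff): the practical effect of the cuda.py change is that the FlexAttention fallback message now names the properties that disqualified the default backend rather than only the dtype. A standalone reproduction of the message formatting (the rejected values are made up):

    import torch

    # Pretend the default backend rejected both the head size and the dtype.
    use_flex_attention_reason = {"head_size": 80, "dtype": torch.float32}

    reason = ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items())
    print("Using FlexAttention backend for %s on V1 engine." % reason)
    # Using FlexAttention backend for head_size=80, dtype=torch.float32 on V1 engine.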


@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import re
 from collections.abc import Sequence
 from typing import Optional, Union
 
+import regex as re
 from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,


@@ -37,6 +37,10 @@ logger = init_logger(__name__)
 class TorchSDPABackend(AttentionBackend):
     accept_output_buffer: bool = False
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         attn_impl = _get_paged_attn_impl()


@@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]


@@ -42,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     cached_sm100a_supported: Optional[bool] = None
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157


@@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
 class FlexAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         return  # FlexAttention supports any head size


@@ -262,6 +262,10 @@ class MLACommonBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [576]


@@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]


@@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
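
Note (not part of the diff): the per-backend hunks above all add the same hook, get_supported_dtypes(), alongside the existing get_supported_head_sizes() or validate_head_size(). The selector only probes these via getattr(), so the interface is easy to illustrate with a hypothetical class (ToyAttentionBackend is not a real vLLM backend):

    import torch


    class ToyAttentionBackend:
        """Hypothetical backend exposing the capability hooks the selector probes."""

        @classmethod
        def get_supported_dtypes(cls) -> list[torch.dtype]:
            return [torch.float16, torch.bfloat16]

        @classmethod
        def get_supported_head_sizes(cls) -> list[int]:
            return [64, 128]


    # The selector checks membership in these lists:
    print(96 in ToyAttentionBackend.get_supported_head_sizes())         # False
    print(torch.float32 in ToyAttentionBackend.get_supported_dtypes())  # False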