From e8cc53af5e17205470c04f442e67f276e08623a1 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Mon, 14 Jul 2025 19:16:51 +0800
Subject: [PATCH] [Misc] Log the reason for falling back to FlexAttention
 (#20699)

Signed-off-by: DarkLight1337
---
 vllm/attention/selector.py                    | 49 ++++++++++++---
 vllm/platforms/cuda.py                        | 59 +++++++++++--------
 .../hunyuan_a13b_reasoning_parser.py          |  2 +-
 vllm/v1/attention/backends/cpu_attn.py        |  4 ++
 vllm/v1/attention/backends/flash_attn.py      |  4 ++
 vllm/v1/attention/backends/flashinfer.py      |  4 ++
 vllm/v1/attention/backends/flex_attention.py  |  4 ++
 vllm/v1/attention/backends/mla/common.py      |  4 ++
 vllm/v1/attention/backends/rocm_aiter_fa.py   |  4 ++
 vllm/v1/attention/backends/triton_attn.py     |  4 ++
 10 files changed, 105 insertions(+), 33 deletions(-)

diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index df14aea729f3c..4d4886d02b78e 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -3,6 +3,7 @@
 
 import os
 from contextlib import contextmanager
+from dataclasses import dataclass
 from functools import cache
 from typing import Generator, Optional, Union
 
@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
-def supports_head_size(
+@dataclass(frozen=True)
+class _IsSupported:
+    can_import: bool
+    head_size: bool
+    dtype: bool
+
+    def __bool__(self) -> bool:
+        return self.can_import and self.head_size and self.dtype
+
+
+def is_attn_backend_supported(
     attn_backend: Union[str, type[AttentionBackend]],
     head_size: int,
-) -> bool:
+    dtype: torch.dtype,
+    *,
+    allow_import_error: bool = True,
+) -> _IsSupported:
     if isinstance(attn_backend, str):
         try:
             attn_backend = resolve_obj_by_qualname(attn_backend)
         except ImportError:
-            return False
+            if not allow_import_error:
+                raise
+
+            return _IsSupported(can_import=False, head_size=False, dtype=False)
 
     assert isinstance(attn_backend, type)
 
     # TODO: Update the interface once V0 is removed
     if get_supported_head_sizes := getattr(attn_backend,
                                            "get_supported_head_sizes", None):
-        return head_size in get_supported_head_sizes()
-    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        is_head_size_supported = head_size in get_supported_head_sizes()
+    elif validate_head_size := getattr(attn_backend, "validate_head_size",
+                                       None):
         try:
             validate_head_size(head_size)
-            return True
+            is_head_size_supported = True
         except Exception:
-            return False
+            is_head_size_supported = False
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "head size validation")
 
-    raise NotImplementedError(f"{attn_backend.__name__} does not support "
-                              "head size validation")
+    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
+                                       None):
+        is_dtype_supported = dtype in get_supported_dtypes()
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "dtype validation")
+
+    return _IsSupported(
+        can_import=True,
+        head_size=is_head_size_supported,
+        dtype=is_dtype_supported,
+    )
 
 
 def get_attn_backend(
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 878f8f77edffa..75b10643c2b5d 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -259,43 +259,56 @@ class CudaPlatformBase(Platform):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return FLASH_ATTN_V1
 
-            from vllm.attention.selector import supports_head_size
+            from vllm.attention.selector import is_attn_backend_supported
 
             # Default backends for V1 engine
-            # FP32 is only supported by FlexAttention
-            if dtype not in (torch.float16, torch.bfloat16):
-                logger.info_once(
-                    "Using FlexAttention backend for %s on V1 engine.",
-                    dtype,
-                )
-                return FLEX_ATTENTION_V1
-
             # Prefer FlashInfer for Blackwell GPUs if installed
-            if cls.is_device_capability(100) and \
-                supports_head_size(FLASHINFER_V1, head_size):
-                try:
-                    import flashinfer  # noqa: F401
-
+            if cls.is_device_capability(100):
+                if is_default_backend_supported := is_attn_backend_supported(
+                        FLASHINFER_V1, head_size, dtype):
                     from vllm.v1.attention.backends.utils import (
                         set_kv_cache_layout)
+
                     logger.info_once(
                         "Using FlashInfer backend with HND KV cache layout on "
                         "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                     set_kv_cache_layout("HND")
+
                     return FLASHINFER_V1
-                except ImportError:
-                    logger.info_once(
+
+                if not is_default_backend_supported.can_import:
+                    logger.warning_once(
                         "FlashInfer failed to import for V1 engine on "
                         "Blackwell (SM 10.0) GPUs; it is recommended to "
                         "install FlashInfer for better performance.")
-                    pass
 
-            # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80) and \
-                supports_head_size(FLASH_ATTN_V1, head_size):
-                logger.info_once("Using Flash Attention backend on V1 engine.")
-                return FLASH_ATTN_V1
-            logger.info_once("Using FlexAttention backend on V1 engine.")
+            # FlashAttention is the default for SM 8.0+ GPUs
+            if cls.has_device_capability(80):
+                if is_default_backend_supported := is_attn_backend_supported(
+                        FLASH_ATTN_V1, head_size, dtype,
+                        allow_import_error=False):
+                    logger.info_once("Using Flash Attention backend on "
+                                     "V1 engine.")
+                    return FLASH_ATTN_V1
+
+            # FlexAttention is the default for older GPUs
+            else:
+                logger.info_once("Using FlexAttention backend on V1 engine.")
+                return FLEX_ATTENTION_V1
+
+            assert not is_default_backend_supported
+
+            use_flex_attention_reason = {}
+            if not is_default_backend_supported.head_size:
+                use_flex_attention_reason["head_size"] = head_size
+            if not is_default_backend_supported.dtype:
+                use_flex_attention_reason["dtype"] = dtype
+
+            logger.info_once(
+                "Using FlexAttention backend for %s on V1 engine.",
+                ", ".join(f"{k}={v}"
+                          for k, v in use_flex_attention_reason.items()),
+            )
             return FLEX_ATTENTION_V1
 
         # Backends for V0 engine
diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
index 598a0e97e515b..fb29d51eae8cf 100644
--- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
+++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import re
 from collections.abc import Sequence
 from typing import Optional, Union
 
+import regex as re
 from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index d6270fbf31969..f1c6bdfc1c941 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -37,6 +37,10 @@ logger = init_logger(__name__)
 class TorchSDPABackend(AttentionBackend):
     accept_output_buffer: bool = False
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         attn_impl = _get_paged_attn_impl()
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index fbc13c06c65aa..552c2caf2fa81 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 4ae595c976b3e..f922e6e4c9e8b 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -42,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     cached_sm100a_supported: Optional[bool] = None
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index a8c5f464aa326..f0f54c28831f4 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
 class FlexAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         return  # FlexAttention supports any head size
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 970de229e139e..1232f73430f8f 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -262,6 +262,10 @@ class MLACommonBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [576]
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 6a78b03dce86e..dd86e56885edb 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index cdaff2f6a40fa..7dc90a6a97e76 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
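
For reference, a minimal standalone sketch of the pattern this patch introduces: a frozen dataclass result that stays truthy only when every check passes, so the caller can log which check caused the fallback. The names IsSupported, DummyBackend, and check below are illustrative stand-ins for this sketch only; they are not part of the patch or of the vLLM API.

# sketch.py -- illustrative only; mirrors the _IsSupported idea from the patch
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class IsSupported:
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        # Truthy only if every individual check passed
        return self.can_import and self.head_size and self.dtype


class DummyBackend:
    """Hypothetical backend exposing the two new classmethods."""

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [64, 128, 256]

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]


def check(backend: type, head_size: int, dtype: torch.dtype) -> IsSupported:
    # Simplified version of the support check: no import handling here
    return IsSupported(
        can_import=True,
        head_size=head_size in backend.get_supported_head_sizes(),
        dtype=dtype in backend.get_supported_dtypes(),
    )


support = check(DummyBackend, head_size=80, dtype=torch.float32)
if not support:
    reasons = {}
    if not support.head_size:
        reasons["head_size"] = 80
    if not support.dtype:
        reasons["dtype"] = torch.float32
    # Prints: falling back: head_size=80, dtype=torch.float32
    print("falling back:", ", ".join(f"{k}={v}" for k, v in reasons.items()))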