[Attention] Refactor FA block_size limitations to hybrid models only (#29084)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-11-22 15:38:44 +01:00
Committed by: GitHub
Parent: 5f7209a793
Commit: 066209a045
GPG Key ID: B5690EEEBB952194
17 changed files with 82 additions and 32 deletions
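
The substance of the change, before the per-file hunks: the supported_kernel_block_sizes class attribute on AttentionBackend becomes a get_supported_kernel_block_sizes() static method, so a backend can compute its supported kernel block sizes at call time from the live vLLM config instead of fixing them at import time. A minimal sketch of the before/after pattern (simplified stand-ins, not the real vLLM types):

# Minimal sketch of the refactor; simplified stand-ins, not the real vLLM types.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for vLLM's MultipleOf marker: "any multiple of `base` is fine".
    base: int


class OldStyleBackend:
    # Before: a class attribute, fixed when the module is imported.
    supported_kernel_block_sizes = [16, 32, 64]


class NewStyleBackend:
    # After: a static method, evaluated when the runner asks, so the answer
    # can depend on the current config (hybrid model, mamba cache dtype, ...).
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

This is what lets FlashAttention keep the [16, 32, 64] restriction for hybrid models only, as the FlashAttentionBackend hunk below shows.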


@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
supported_sizes = backend.get_class().supported_kernel_block_sizes
supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (


@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
supported_kernel_block_sizes = supported_sizes
@staticmethod
def get_supported_kernel_block_sizes():
return supported_sizes
return _MockBackend()


@ -46,9 +46,12 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)]
@staticmethod
@abstractmethod
def get_name() -> str:
@ -142,10 +145,11 @@ class AttentionBackend(ABC):
if block_size not in valid_sizes:
return False
if not cls.supported_kernel_block_sizes:
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes:
return True
for supported_size in cls.supported_kernel_block_sizes:
for supported_size in supported_kernel_block_sizes:
if isinstance(supported_size, MultipleOf):
supported_size = supported_size.base
# With hybrid_blocks feature, the framework-level block size
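
For reference, a self-contained sketch of how the check above treats the two kinds of entries, assuming (per the truncated comment) that with the hybrid-blocks feature the framework-level block size only needs to be divisible by a kernel block size the backend declares. The helper name is illustrative, not the real classmethod:

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:  # stand-in for vLLM's marker type
    base: int


def kernel_block_size_ok(
    block_size: int, supported: list[int | MultipleOf]
) -> bool:
    if not supported:
        return True  # backend declares no kernel-level restriction
    for supported_size in supported:
        if isinstance(supported_size, MultipleOf):
            supported_size = supported_size.base
        if block_size % supported_size == 0:
            return True
    return False


assert kernel_block_size_ok(48, [MultipleOf(16)])   # 48 is a multiple of 16
assert kernel_block_size_ok(128, [16, 32, 64])      # 128 = 2 * 64
assert not kernel_block_size_ok(24, [16, 32, 64])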


@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@ -56,11 +56,26 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if (
model_config
and model_config.is_hybrid
and (
cache_config.mamba_ssm_cache_dtype == "float32"
or cache_config.mamba_cache_dtype == "float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
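
This is the behavioral core of the PR: the [16, 32, 64] restriction (the NaN-propagation workaround) now applies only when the model is hybrid and either mamba cache dtype is float32; every other model accepts any multiple of 16 again. A hypothetical standalone helper that mirrors the branch, with the config fields passed in explicitly instead of read via get_current_vllm_config():

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:  # same stand-in as in the earlier sketch
    base: int


def fa_supported_block_sizes(
    is_hybrid: bool,
    mamba_cache_dtype: str,
    mamba_ssm_cache_dtype: str,
) -> list[int | MultipleOf]:
    # Hybrid model with a float32 mamba cache: restrict to the sizes that
    # avoid the NaN propagation issue (Dao-AILab/flash-attention#1974).
    if is_hybrid and "float32" in (mamba_cache_dtype, mamba_ssm_cache_dtype):
        return [16, 32, 64]
    # Everything else: any multiple of 16 is accepted.
    return [MultipleOf(16)]


assert fa_supported_block_sizes(True, "float32", "auto") == [16, 32, 64]
assert fa_supported_block_sizes(True, "auto", "float32") == [16, 32, 64]
assert fa_supported_block_sizes(False, "float32", "auto") == [MultipleOf(16)]
assert fa_supported_block_sizes(True, "auto", "auto") == [MultipleOf(16)]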


@ -16,7 +16,6 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# Note: Not sure for all platforms,
# but on Blackwell, only support a page size of
# 16, 32, 64
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# Note: Not sure for all platforms, but on Blackwell,
# only support a page size of 16, 32, 64.
return [16, 32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER"
@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
@classmethod
@override
def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"],
vllm_config: VllmConfig,


@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [128]
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA"


@ -41,9 +41,12 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"


@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [32, 64]
@staticmethod
def get_name() -> str:
return "FLASHINFER_MLA"


@ -39,13 +39,16 @@ logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
"fp8_e4m3",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA"


@ -55,9 +55,12 @@ structured as:
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [64]
@staticmethod
def get_name() -> str:
return "FLASHMLA_SPARSE"


@ -24,9 +24,9 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
1 if current_platform.is_rocm() else 64
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:


@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
@staticmethod
def get_name() -> str:


@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:


@ -31,7 +31,10 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:


@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
torch.bfloat16,
torch.float32,
]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"fp8",
@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "TRITON_ATTN"


@ -42,7 +42,10 @@ logger = init_logger(__name__)
class XFormersAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:


@ -4618,7 +4618,7 @@ class GPUModelRunner(
"""
for backend in backends:
is_supported = False
for supported_size in backend.supported_kernel_block_sizes:
for supported_size in backend.get_supported_kernel_block_sizes():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
@ -4649,7 +4649,7 @@ class GPUModelRunner(
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.supported_kernel_block_sizes
for supported_size in backend.get_supported_kernel_block_sizes()
if isinstance(supported_size, int)
)
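
Both runner call sites now go through the method. A hypothetical end-to-end usage sketch (toy backends, not the runner's real selection logic), reusing the divisibility assumption from the earlier sketch: collect the exact int sizes any backend advertises, then keep the ones every backend accepts.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:  # stand-in, as in the earlier sketches
    base: int


class _ExactSizesBackend:
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [16, 32, 64]


class _AnyMultipleOf16Backend:
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]


def accepts(backend, block_size: int) -> bool:
    supported = backend.get_supported_kernel_block_sizes()
    if not supported:
        return True  # no kernel-level restriction declared
    return any(
        block_size % (s.base if isinstance(s, MultipleOf) else s) == 0
        for s in supported
    )


backends = [_ExactSizesBackend, _AnyMultipleOf16Backend]
# Exact sizes advertised anywhere (mirrors the set comprehension above)...
candidates = {
    s
    for b in backends
    for s in b.get_supported_kernel_block_sizes()
    if isinstance(s, int)
}
# ...narrowed to those every backend accepts.
common = sorted(c for c in candidates if all(accepts(b, c) for b in backends))
assert common == [16, 32, 64]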