diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 1bd05e6183dc..783e02ce89bd 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:
 
 BACKEND_BLOCK_SIZES = {}
 for backend in BACKENDS_TO_TEST:
-    supported_sizes = backend.get_class().supported_kernel_block_sizes
+    supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
     if supported_sizes:
         default_size = supported_sizes[0]
         block_size = (
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 01c1364f7ee6..d0f1b703fcb9 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
     supported_sizes: list[int | MultipleOf],
 ):
     class _MockBackend:
-        supported_kernel_block_sizes = supported_sizes
+        @staticmethod
+        def get_supported_kernel_block_sizes():
+            return supported_sizes
 
     return _MockBackend()
 
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 67ded8847524..bd7e81b15bfc 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
     supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(1)]
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -142,10 +145,11 @@ class AttentionBackend(ABC):
         if block_size not in valid_sizes:
             return False
 
-        if not cls.supported_kernel_block_sizes:
+        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_kernel_block_sizes:
             return True
 
-        for supported_size in cls.supported_kernel_block_sizes:
+        for supported_size in supported_kernel_block_sizes:
             if isinstance(supported_size, MultipleOf):
                 supported_size = supported_size.base
             # With hybrid_blocks feature, the framework-level block size
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 9fa6b1dfd19d..a9a4af5ac118 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
         get_scheduler_metadata,
         reshape_and_cache_flash,
     )
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@@ -56,11 +56,26 @@ logger = init_logger(__name__)
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        if (
+            model_config
+            and model_config.is_hybrid
+            and (
+                cache_config.mamba_ssm_cache_dtype == "float32"
+                or cache_config.mamba_cache_dtype == "float32"
+            )
+        ):
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]
 
     @staticmethod
     def get_name() -> str:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index e3f499216d7f..8159f4096107 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -16,7 +16,6 @@ from flashinfer import (
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
-from typing_extensions import override
 
 from vllm import envs
 from vllm.attention.backends.abstract import (
@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # Note: Not sure for all platforms,
-    # but on Blackwell, only support a page size of
-    # 16, 32, 64
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
         "fp8_e5m2",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        # Note: Not sure for all platforms, but on Blackwell,
+        # only support a page size of 16, 32, 64.
+        return [16, 32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER"
@@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
 
     @classmethod
-    @override
     def get_cudagraph_support(
         cls: type["FlashInferMetadataBuilder"],
         vllm_config: VllmConfig,
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index 60cb5022a55e..5e3fbc0abf08 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
 
 class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [128]
+
     @staticmethod
     def get_name() -> str:
         return "CUTLASS_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 12639edc8b9a..d369814c10b6 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -41,9 +41,12 @@ logger = init_logger(__name__)
 
 class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 52bb19e039e4..f02a4bb1ef35 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
 
 class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER_MLA"
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 3aab1f9bb7fb..74a4cd843025 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -39,13 +39,16 @@ logger = init_logger(__name__)
 
 class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA"
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 3f2cc8c38327..1eee1d225293 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -55,9 +55,12 @@ structured as:
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA_SPARSE"
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index d38361e0fcbf..77f1ba00d5b0 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -24,9 +24,9 @@ logger = init_logger(__name__)
 
 
 class DeepseekV32IndexerBackend(AttentionBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
-        1 if current_platform.is_rocm() else 64
-    ]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1 if current_platform.is_rocm() else 64]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 6ccc1a341d56..56f9c7a281e7 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 
 class AiterMLABackend(MLACommonBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1]
 
     @staticmethod
     def get_name() -> str:
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index ea611848b0e8..c8742e983520 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
 class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 1bf38ed225a4..523f759e05a2 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -31,7 +31,10 @@ logger = init_logger(__name__)
 class TreeAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 09c36043c8c8..d051a89f03bb 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
         "fp8_e5m2",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index d15d79417cc6..5039c44b9c3e 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -42,7 +42,10 @@ logger = init_logger(__name__)
 class XFormersAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e786cd8bc7c9..298bb1ef5f6f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4618,7 +4618,7 @@ class GPUModelRunner(
         """
         for backend in backends:
             is_supported = False
-            for supported_size in backend.supported_kernel_block_sizes:
+            for supported_size in backend.get_supported_kernel_block_sizes():
                 if isinstance(supported_size, int):
                     if block_size == supported_size:
                         is_supported = True
@@ -4649,7 +4649,7 @@ class GPUModelRunner(
         all_int_supported_sizes = set(
             supported_size
             for backend in backends
-            for supported_size in backend.supported_kernel_block_sizes
+            for supported_size in backend.get_supported_kernel_block_sizes()
             if isinstance(supported_size, int)
         )
 
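For reference, a minimal standalone sketch (not part of the patch) of the block-size check that now consumes the values returned by get_supported_kernel_block_sizes(): integer entries require an exact match, as in the GPUModelRunner loop above, while a MultipleOf entry accepts any block size divisible by its base. The MultipleOf and supports_block_size names below are simplified stand-ins rather than the vLLM implementations, and the sketch ignores the hybrid-blocks relaxation referenced in abstract.py.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Simplified stand-in: marks "any block size divisible by `base`".
    base: int


def supports_block_size(supported: list[int | MultipleOf], block_size: int) -> bool:
    # An empty list means "no kernel constraint", mirroring the early return
    # in AttentionBackend.supports_block_size.
    if not supported:
        return True
    for supported_size in supported:
        if isinstance(supported_size, MultipleOf):
            if block_size % supported_size.base == 0:
                return True
        elif block_size == supported_size:
            return True
    return False


# FlashAttention's default case advertises [MultipleOf(16)]; the hybrid-model
# path in the patch returns the explicit list [16, 32, 64] instead.
assert supports_block_size([MultipleOf(16)], 48)
assert not supports_block_size([16, 32, 64], 48)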