[Attention] Refactor FA block_size limitations to hybrid models only (#29084)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent 5f7209a793
commit 066209a045
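In short, the refactor turns the class-level `supported_kernel_block_sizes` attribute into a `get_supported_kernel_block_sizes()` static method, so a backend such as FlashAttention can give a config-dependent answer instead of a value fixed at import time. The sketch below is illustrative only: the `MultipleOf` dataclass and the two toy backends are stand-ins, not vLLM code, and merely show how the call sites change shape.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    """Illustrative stand-in for vLLM's MultipleOf marker: any block size
    that is a positive multiple of `base` is acceptable."""
    base: int


class OldStyleBackend:
    # Before: a static class attribute, fixed at import time.
    supported_kernel_block_sizes = [16, 32, 64]


class NewStyleBackend:
    # After: a static method, so the answer can depend on runtime config
    # (e.g. hybrid models pinning FA to {16, 32, 64}).
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]


# Call sites change from an attribute read to a method call:
old_sizes = OldStyleBackend.supported_kernel_block_sizes
new_sizes = NewStyleBackend.get_supported_kernel_block_sizes()
print(old_sizes, new_sizes)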
@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:

 BACKEND_BLOCK_SIZES = {}
 for backend in BACKENDS_TO_TEST:
-    supported_sizes = backend.get_class().supported_kernel_block_sizes
+    supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
     if supported_sizes:
         default_size = supported_sizes[0]
         block_size = (
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
     supported_sizes: list[int | MultipleOf],
 ):
     class _MockBackend:
-        supported_kernel_block_sizes = supported_sizes
+        @staticmethod
+        def get_supported_kernel_block_sizes():
+            return supported_sizes

     return _MockBackend()

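For context on how the tests exercise this, here is a self-contained illustration combining the two test hunks (the helper that fabricates a backend, and the loop that takes the first supported entry as the default block size); the toy names mirror the diff but the assertions are illustrative, not taken from the test file.

# Illustration only: toy stand-ins for the test helpers in the hunks above.
def _make_mock_backend_for_kernel_block_size(supported_sizes):
    class _MockBackend:
        @staticmethod
        def get_supported_kernel_block_sizes():
            return supported_sizes

    return _MockBackend()


backend = _make_mock_backend_for_kernel_block_size([64, 128])
supported_sizes = backend.get_supported_kernel_block_sizes()
# Mirrors the first hunk: the default block size is the first supported entry.
default_size = supported_sizes[0] if supported_sizes else None
assert default_size == 64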
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
     supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(1)]
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:

@@ -142,10 +145,11 @@ class AttentionBackend(ABC):
         if block_size not in valid_sizes:
             return False

-        if not cls.supported_kernel_block_sizes:
+        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_kernel_block_sizes:
             return True

-        for supported_size in cls.supported_kernel_block_sizes:
+        for supported_size in supported_kernel_block_sizes:
             if isinstance(supported_size, MultipleOf):
                 supported_size = supported_size.base
             # With hybrid_blocks feature, the framework-level block size
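A minimal, self-contained sketch of the validation loop above, assuming a block size is accepted when it equals an integer entry or is a multiple of a `MultipleOf` entry's base (the divisibility rule is an assumption inferred from the `.base` handling shown; the `MultipleOf` dataclass here is a stand-in for vLLM's class):

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for vLLM's MultipleOf marker used in the hunk above.
    base: int


def is_block_size_supported(
    block_size: int, supported: list[int | MultipleOf]
) -> bool:
    """Sketch of the validation loop: an empty list means 'no constraint';
    otherwise the block size must match an int entry exactly or (assumption)
    be a multiple of a MultipleOf entry's base."""
    if not supported:
        return True
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if block_size % entry.base == 0:
                return True
        elif block_size == entry:
            return True
    return False


assert is_block_size_supported(48, [MultipleOf(16)])   # 48 is a multiple of 16
assert is_block_size_supported(32, [16, 32, 64])       # exact match
assert not is_block_size_supported(48, [16, 32, 64])   # not in the discrete list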
@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
         get_scheduler_metadata,
         reshape_and_cache_flash,
     )
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@@ -56,11 +56,26 @@ logger = init_logger(__name__)
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        if (
+            model_config
+            and model_config.is_hybrid
+            and (
+                cache_config.mamba_ssm_cache_dtype == "float32"
+                or cache_config.mamba_cache_dtype == "float32"
+            )
+        ):
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]

     @staticmethod
     def get_name() -> str:
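The effect of the new method is that the {16, 32, 64} restriction applies only to hybrid models whose Mamba cache is kept in float32; all other models accept any multiple of 16. A self-contained restatement of that branch, using toy config objects rather than vLLM's real `VllmConfig`, assuming the same field names as in the diff:

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


@dataclass
class ToyModelConfig:          # stand-in for vLLM's ModelConfig
    is_hybrid: bool


@dataclass
class ToyCacheConfig:          # stand-in for vLLM's CacheConfig
    mamba_ssm_cache_dtype: str = "auto"
    mamba_cache_dtype: str = "auto"


def fa_supported_block_sizes(
    model_config: ToyModelConfig | None, cache_config: ToyCacheConfig
) -> list[int | MultipleOf]:
    # Hybrid models with a float32 Mamba cache keep the conservative list
    # {16, 32, 64} to avoid the NaN propagation issue referenced above;
    # everything else accepts any multiple of 16.
    if model_config and model_config.is_hybrid and (
        cache_config.mamba_ssm_cache_dtype == "float32"
        or cache_config.mamba_cache_dtype == "float32"
    ):
        return [16, 32, 64]
    return [MultipleOf(16)]


print(fa_supported_block_sizes(ToyModelConfig(is_hybrid=True),
                               ToyCacheConfig(mamba_cache_dtype="float32")))
print(fa_supported_block_sizes(ToyModelConfig(is_hybrid=False),
                               ToyCacheConfig()))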
@@ -16,7 +16,6 @@ from flashinfer import (
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
-from typing_extensions import override

 from vllm import envs
 from vllm.attention.backends.abstract import (

@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # Note: Not sure for all platforms,
-    # but on Blackwell, only support a page size of
-    # 16, 32, 64
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",

@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
         "fp8_e5m2",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        # Note: Not sure for all platforms, but on Blackwell,
+        # only support a page size of 16, 32, 64.
+        return [16, 32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER"

@@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )

     @classmethod
-    @override
     def get_cudagraph_support(
         cls: type["FlashInferMetadataBuilder"],
         vllm_config: VllmConfig,
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):

 class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [128]
+
     @staticmethod
     def get_name() -> str:
         return "CUTLASS_MLA"
@@ -41,9 +41,12 @@ logger = init_logger(__name__)

 class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN_MLA"
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):

 class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER_MLA"
@@ -39,13 +39,16 @@ logger = init_logger(__name__)

 class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA"
@@ -55,9 +55,12 @@ structured as:
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA_SPARSE"
@@ -24,9 +24,9 @@ logger = init_logger(__name__)


 class DeepseekV32IndexerBackend(AttentionBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
-        1 if current_platform.is_rocm() else 64
-    ]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1 if current_platform.is_rocm() else 64]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec


 class AiterMLABackend(MLACommonBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1]

     @staticmethod
     def get_name() -> str:
@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
 class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -31,7 +31,10 @@ logger = init_logger(__name__)
 class TreeAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",

@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
         "fp8_e5m2",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"
@@ -42,7 +42,10 @@ logger = init_logger(__name__)
 class XFormersAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -4618,7 +4618,7 @@ class GPUModelRunner(
         """
         for backend in backends:
             is_supported = False
-            for supported_size in backend.supported_kernel_block_sizes:
+            for supported_size in backend.get_supported_kernel_block_sizes():
                 if isinstance(supported_size, int):
                     if block_size == supported_size:
                         is_supported = True

@@ -4649,7 +4649,7 @@ class GPUModelRunner(
         all_int_supported_sizes = set(
             supported_size
             for backend in backends
-            for supported_size in backend.supported_kernel_block_sizes
+            for supported_size in backend.get_supported_kernel_block_sizes()
             if isinstance(supported_size, int)
         )

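Finally, the model runner now queries the method when it collects the integer block sizes supported by every backend. The sketch below is an illustrative heuristic for picking a common block size, not the exact `GPUModelRunner` algorithm: it intersects the discrete sizes across backends and falls back to the least common multiple of the `MultipleOf` bases.

from dataclasses import dataclass
from math import lcm


@dataclass(frozen=True)
class MultipleOf:
    base: int


def pick_common_block_size(all_supported: list[list[int | MultipleOf]]) -> int:
    """Illustrative heuristic: gather the discrete integer sizes, keep those
    accepted by every backend, and fall back to the least common multiple of
    the MultipleOf bases."""
    int_sizes = {s for sizes in all_supported for s in sizes if isinstance(s, int)}

    def accepted_by(size: int, sizes: list[int | MultipleOf]) -> bool:
        if not sizes:
            return True
        return any(
            size % s.base == 0 if isinstance(s, MultipleOf) else size == s
            for s in sizes
        )

    candidates = sorted(s for s in int_sizes
                        if all(accepted_by(s, sizes) for sizes in all_supported))
    if candidates:
        return candidates[0]
    bases = [s.base for sizes in all_supported for s in sizes
             if isinstance(s, MultipleOf)]
    return lcm(*bases) if bases else 1


# FlashAttention (non-hybrid) + FlashInfer: smallest mutually valid size is 16.
print(pick_common_block_size([[MultipleOf(16)], [16, 32, 64]]))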