[Attention] Refactor FA block_size limitations to hybrid models only (#29084)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent 5f7209a793
commit 066209a045
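This PR replaces the `supported_kernel_block_sizes` class attribute on attention backends with a `get_supported_kernel_block_sizes()` static method, so a backend can consult the live vLLM config (for example, whether the model is hybrid and how its Mamba cache is typed) at the moment the engine asks which kernel block sizes it accepts. A minimal sketch of the interface change, using simplified stand-ins (OldStyleBackend, NewStyleBackend and this MultipleOf definition are illustrative placeholders, not the real vLLM types):

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for vLLM's MultipleOf marker: "any block size divisible by base".
    base: int


# Before this PR: a class-level constant, fixed at import time.
class OldStyleBackend:
    supported_kernel_block_sizes = [16, 32, 64]


# After this PR: a static method, evaluated when the engine asks, so the
# answer can depend on the currently active config (see the FlashAttention
# hunk further down).
class NewStyleBackend:
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]


print(OldStyleBackend.supported_kernel_block_sizes)         # [16, 32, 64]
print(NewStyleBackend.get_supported_kernel_block_sizes())   # [MultipleOf(base=16)]

The diff below touches the abstract base class, every concrete backend, the kernel-block-size tests, and the GPU model runner call sites.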
@@ -61,7 +61,7 @@ for backend in BACKENDS_TO_TEST:

 BACKEND_BLOCK_SIZES = {}
 for backend in BACKENDS_TO_TEST:
-    supported_sizes = backend.get_class().supported_kernel_block_sizes
+    supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
     if supported_sizes:
         default_size = supported_sizes[0]
         block_size = (
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
     supported_sizes: list[int | MultipleOf],
 ):
     class _MockBackend:
-        supported_kernel_block_sizes = supported_sizes
+        @staticmethod
+        def get_supported_kernel_block_sizes():
+            return supported_sizes

     return _MockBackend()

@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
     supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(1)]
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -142,10 +145,11 @@ class AttentionBackend(ABC):
         if block_size not in valid_sizes:
             return False

-        if not cls.supported_kernel_block_sizes:
+        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_kernel_block_sizes:
             return True

-        for supported_size in cls.supported_kernel_block_sizes:
+        for supported_size in supported_kernel_block_sizes:
             if isinstance(supported_size, MultipleOf):
                 supported_size = supported_size.base
             # With hybrid_blocks feature, the framework-level block size
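The `supports_block_size` hunk above treats each entry either as an exact integer or as a `MultipleOf` divisibility constraint. A standalone sketch of that matching logic, assuming only a minimal `MultipleOf` with a `base` field (this free function approximates the classmethod's behaviour; it is not its exact body):

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


def supports_block_size(
    supported: list[int | MultipleOf], block_size: int | None
) -> bool:
    # No requested size means there is nothing to validate.
    if block_size is None:
        return True
    # An empty list is treated as "no restriction", as in the hunk above.
    if not supported:
        return True
    for entry in supported:
        if isinstance(entry, MultipleOf):
            # Divisibility constraint: any multiple of `base` is acceptable.
            if block_size % entry.base == 0:
                return True
        elif block_size == entry:
            return True
    return False


assert supports_block_size([16, 32, 64], 32)
assert not supports_block_size([16, 32, 64], 48)
assert supports_block_size([MultipleOf(16)], 48)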
@@ -32,7 +32,7 @@ if is_flash_attn_varlen_func_available():
         get_scheduler_metadata,
         reshape_and_cache_flash,
     )
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@@ -56,11 +56,26 @@ logger = init_logger(__name__)
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        if (
+            model_config
+            and model_config.is_hybrid
+            and (
+                cache_config.mamba_ssm_cache_dtype == "float32"
+                or cache_config.mamba_cache_dtype == "float32"
+            )
+        ):
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]

     @staticmethod
     def get_name() -> str:
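FlashAttention is the one backend where the answer is now config-dependent: hybrid models that keep Mamba state in float32 are limited to block sizes 16/32/64 (to avoid the NaN-propagation issue linked in the comment), while every other model keeps the generic MultipleOf(16) rule. A rough sketch of that decision, where FakeCfg and fa_supported_block_sizes are hypothetical stand-ins for the real config objects and method:

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


@dataclass
class FakeCfg:
    # Hypothetical stand-in for the fields read off vLLM's model/cache config.
    is_hybrid: bool
    mamba_cache_dtype: str
    mamba_ssm_cache_dtype: str


def fa_supported_block_sizes(cfg: FakeCfg) -> list[int | MultipleOf]:
    # Restrict only hybrid models whose Mamba state is kept in float32;
    # every other model keeps the generic MultipleOf(16) behaviour.
    if cfg.is_hybrid and "float32" in (
        cfg.mamba_cache_dtype,
        cfg.mamba_ssm_cache_dtype,
    ):
        return [16, 32, 64]
    return [MultipleOf(16)]


print(fa_supported_block_sizes(FakeCfg(True, "float32", "auto")))  # [16, 32, 64]
print(fa_supported_block_sizes(FakeCfg(False, "auto", "auto")))    # [MultipleOf(base=16)]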
@@ -16,7 +16,6 @@ from flashinfer import (
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
-from typing_extensions import override

 from vllm import envs
 from vllm.attention.backends.abstract import (
@@ -275,10 +274,6 @@ class BatchDCPPrefillWrapper:
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # Note: Not sure for all platforms,
-    # but on Blackwell, only support a page size of
-    # 16, 32, 64
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
@@ -286,6 +281,12 @@ class FlashInferBackend(AttentionBackend):
         "fp8_e5m2",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        # Note: Not sure for all platforms, but on Blackwell,
+        # only support a page size of 16, 32, 64.
+        return [16, 32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER"
@@ -566,7 +567,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )

     @classmethod
-    @override
     def get_cudagraph_support(
         cls: type["FlashInferMetadataBuilder"],
         vllm_config: VllmConfig,
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):

 class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [128]
+
     @staticmethod
     def get_name() -> str:
         return "CUTLASS_MLA"
@@ -41,9 +41,12 @@ logger = init_logger(__name__)

 class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN_MLA"
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):

 class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER_MLA"
@@ -39,13 +39,16 @@ logger = init_logger(__name__)

 class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA"
@@ -55,9 +55,12 @@ structured as:
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA_SPARSE"
@@ -24,9 +24,9 @@ logger = init_logger(__name__)


 class DeepseekV32IndexerBackend(AttentionBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
-        1 if current_platform.is_rocm() else 64
-    ]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1 if current_platform.is_rocm() else 64]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
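Backends whose limits depend on the platform benefit from the same move: with a method, the platform query runs when the engine asks rather than when the module is imported. A toy illustration, assuming a fake platform object (_FakePlatform and IndexerBackendSketch are invented names, not vLLM classes):

class _FakePlatform:
    # Illustrative stand-in for vllm.platforms.current_platform; the real
    # object is queried when the method runs, not when the module is imported.
    @staticmethod
    def is_rocm() -> bool:
        return False


current_platform = _FakePlatform()


class IndexerBackendSketch:
    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int]:
        # ROCm uses block size 1, everything else 64, decided lazily.
        return [1 if current_platform.is_rocm() else 64]


print(IndexerBackendSketch.get_supported_kernel_block_sizes())  # [64]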
@@ -21,7 +21,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec


 class AiterMLABackend(MLACommonBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1]

     @staticmethod
     def get_name() -> str:
@@ -447,7 +447,10 @@ class AiterFlashAttentionMetadataBuilder(
 class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -31,7 +31,10 @@ logger = init_logger(__name__)
 class TreeAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -154,7 +154,6 @@ class TritonAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
@@ -162,6 +161,10 @@ class TritonAttentionBackend(AttentionBackend):
         "fp8_e5m2",
     ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"
@@ -42,7 +42,10 @@ logger = init_logger(__name__)
 class XFormersAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -4618,7 +4618,7 @@ class GPUModelRunner(
         """
         for backend in backends:
             is_supported = False
-            for supported_size in backend.supported_kernel_block_sizes:
+            for supported_size in backend.get_supported_kernel_block_sizes():
                 if isinstance(supported_size, int):
                     if block_size == supported_size:
                         is_supported = True
@@ -4649,7 +4649,7 @@ class GPUModelRunner(
         all_int_supported_sizes = set(
             supported_size
             for backend in backends
-            for supported_size in backend.supported_kernel_block_sizes
+            for supported_size in backend.get_supported_kernel_block_sizes()
             if isinstance(supported_size, int)
         )

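On the runner side, the selection logic mirrors the backend check: an explicit integer entry must match the configured block size exactly, a MultipleOf entry accepts any multiple of its base, and the second hunk collects the union of all exact integer sizes across backends when searching for a common value. A condensed, self-contained sketch (function and variable names here are illustrative, not the runner's actual helpers):

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


BlockSizes = list[int | MultipleOf]


def block_size_supported_by_all(backends: list[BlockSizes], block_size: int) -> bool:
    # Every backend must accept the candidate size, exactly or as a multiple.
    for supported in backends:
        ok = False
        for entry in supported:
            if isinstance(entry, int):
                ok = ok or block_size == entry
            else:
                ok = ok or block_size % entry.base == 0
        if not ok:
            return False
    return True


def candidate_exact_sizes(backends: list[BlockSizes]) -> set[int]:
    # Union of all explicit integer sizes, used when hunting for a common value.
    return {e for supported in backends for e in supported if isinstance(e, int)}


backends = [[16, 32, 64], [MultipleOf(16)]]
print(block_size_supported_by_all(backends, 32))  # True
print(block_size_supported_by_all(backends, 48))  # False: first backend rejects it
print(candidate_exact_sizes(backends))            # {16, 32, 64}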