[V1] Support any head size for FlexAttention backend (#20467)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-07-07 00:54:36 +08:00 committed by GitHub
parent e202dd2736
commit 9fb52e523a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 202 additions and 118 deletions

View File

@ -107,10 +107,9 @@ fi
if [[ $commands == *" kernels/attention"* ]]; then if [[ $commands == *" kernels/attention"* ]]; then
commands="${commands} \ commands="${commands} \
--ignore=kernels/attention/stest_attention_selector.py \ --ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_blocksparse_attention.py \ --ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/attention/test_encoder_decoder_attn.py \ --ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_flash_attn.py \ --ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/attention/test_flashinfer.py \ --ignore=kernels/attention/test_flashinfer.py \
--ignore=kernels/attention/test_prefix_prefill.py \ --ignore=kernels/attention/test_prefix_prefill.py \

View File

@ -626,9 +626,6 @@ Specified using `--task generate`.
!!! note !!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
!!! note
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
!!! note !!! note
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
@ -671,11 +668,8 @@ Specified using `--task generate`.
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
!!! note !!! note
To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
`pip install git+https://github.com/huggingface/transformers.git`. is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
`--mm-processor-kwargs '{"use_audio_in_video": true}'`.
#### Transcription #### Transcription

View File

@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions] prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs( engine_args = EngineArgs(
model="Salesforce/blip2-opt-6.7b", model="Salesforce/blip2-opt-2.7b",
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
) )
# Qwen # Qwen-VL
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"

View File

@ -172,7 +172,7 @@ def test_env(
expected = "FLASHINFER_VLLM_V1" if use_v1 else name expected = "FLASHINFER_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
backend = get_attn_backend(16, backend = get_attn_backend(32,
torch.float16, torch.float16,
torch.float16, torch.float16,
block_size, block_size,
@ -181,6 +181,17 @@ def test_env(
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected assert backend.get_name() == expected
if use_v1:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == "FLEX_ATTENTION", (
"Should fallback to FlexAttention if head size is "
"not supported by FlashAttention")
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False]) @pytest.mark.parametrize("use_v1", [True, False])

View File

@ -33,9 +33,6 @@ if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
REQUIRES_V0_MODELS = [ REQUIRES_V0_MODELS = [
# V1 Test: no way to fall back for head_dim = 80
# https://github.com/vllm-project/vllm/issues/14524
"qwen_vl",
# V1 Test: not enough KV cache space in C1. # V1 Test: not enough KV cache space in C1.
"fuyu", "fuyu",
] ]
@ -221,8 +218,7 @@ VLM_TEST_SETTINGS = {
marks=[large_gpu_mark(min_gb=32)], marks=[large_gpu_mark(min_gb=32)],
), ),
"blip2": VLMTestInfo( "blip2": VLMTestInfo(
# TODO: Change back to 2.7b once head_dim = 80 is supported models=["Salesforce/blip2-opt-2.7b"],
models=["Salesforce/blip2-opt-6.7b"],
test_type=VLMTestType.IMAGE, test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:", prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
img_idx_to_prompt=lambda idx: "", img_idx_to_prompt=lambda idx: "",
@ -340,8 +336,7 @@ VLM_TEST_SETTINGS = {
"h2ovl": VLMTestInfo( "h2ovl": VLMTestInfo(
models = [ models = [
"h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-800m",
# TODO: Re-enable once head_dim = 80 is supported "h2oai/h2ovl-mississippi-2b",
# "h2oai/h2ovl-mississippi-2b",
], ],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501

View File

@ -83,7 +83,7 @@ MODELS = [
QWEN2_CONFIG, QWEN2_CONFIG,
PHI3_CONFIG, PHI3_CONFIG,
GPT2_CONFIG, GPT2_CONFIG,
# STABLELM_CONFIG, # enable this when v1 support head_size=80 STABLELM_CONFIG,
DOLPHIN_CONFIG, DOLPHIN_CONFIG,
# STARCODER_CONFIG, # broken # STARCODER_CONFIG, # broken
] ]

View File

@ -240,8 +240,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True), trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
# Blocksparse attention not supported in V1 yet
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True, trust_remote_code=True,
v0_only=True), v0_only=True),
@ -258,10 +259,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
v0_only=True), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
v0_only=True),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@ -330,8 +329,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501 "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501 extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
v0_only=True),
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@ -359,8 +357,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True, trust_remote_code=True),
v0_only=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
max_model_len=10240), max_model_len=10240),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",

View File

@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
# FIXME: Possible memory leak in the previous tests? # FIXME: Possible memory leak in the previous tests?
if model_arch == "GraniteSpeechForConditionalGeneration": if model_arch in ("GraniteSpeechForConditionalGeneration",
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM") pytest.skip("Avoid OOM")
# Avoid OOM and reduce initialization time by only using 1 layer # Avoid OOM and reduce initialization time by only using 1 layer

View File

@ -310,7 +310,8 @@ class MultiHeadAttention(nn.Module):
# currently, only torch_sdpa is supported on rocm # currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA self.attn_backend = _Backend.TORCH_SDPA
else: else:
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS backend = _Backend.XFORMERS
self.attn_backend = backend if backend in { self.attn_backend = backend if backend in {

View File

@ -4,7 +4,7 @@
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from functools import cache from functools import cache
from typing import Generator, Optional, Type from typing import Generator, Optional, Union
import torch import torch
@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend return forced_attn_backend
def supports_head_size(
attn_backend: Union[str, type[AttentionBackend]],
head_size: int,
) -> bool:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
return False
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(attn_backend,
"get_supported_head_sizes", None):
return head_size in get_supported_head_sizes()
if validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
return True
except Exception:
return False
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")
def get_attn_backend( def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
@ -87,7 +114,7 @@ def get_attn_backend(
is_attention_free: bool, is_attention_free: bool,
is_blocksparse: bool = False, is_blocksparse: bool = False,
use_mla: bool = False, use_mla: bool = False,
) -> Type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong # Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls. # value to be returned from the cache if the value changes between calls.
@ -115,7 +142,7 @@ def _cached_get_attn_backend(
is_blocksparse: bool = False, is_blocksparse: bool = False,
use_v1: bool = False, use_v1: bool = False,
use_mla: bool = False, use_mla: bool = False,
) -> Type[AttentionBackend]: ) -> type[AttentionBackend]:
if is_blocksparse: if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.") logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import ( from vllm.attention.backends.blocksparse_attn import (

View File

@ -2319,7 +2319,7 @@ class SchedulerConfig:
if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
logger.warning( logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs" "max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.", "* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens, self.max_num_batched_tokens,
self.max_num_seqs * self.max_model_len) self.max_num_seqs * self.max_model_len)

View File

@ -234,35 +234,44 @@ class CudaPlatformBase(Platform):
return ("vllm.attention.backends." return ("vllm.attention.backends."
"flashmla.FlashMLABackend") "flashmla.FlashMLABackend")
if use_v1: if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
if selected_backend == _Backend.FLASHINFER: if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.") logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION: elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttenion backend on V1 engine.") logger.info_once("Using FlexAttention backend on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1: elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends." return TRITON_ATTN_VLLM_V1
"triton_attn.TritonAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.") logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return FLASH_ATTN_V1
"flash_attn.FlashAttentionBackend")
from vllm.attention.selector import supports_head_size
# Default backends for V1 engine # Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed # FP32 is only supported by FlexAttention
if dtype not in (torch.float16, torch.bfloat16): if dtype not in (torch.float16, torch.bfloat16):
logger.info_once( logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.") "Using FlexAttention backend for %s on V1 engine.",
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 dtype,
if cls.is_device_capability(100): )
return FLEX_ATTENTION_V1
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100) and \
supports_head_size(FLASHINFER_V1, head_size):
try: try:
import flashinfer # noqa: F401 import flashinfer # noqa: F401
logger.info_once( logger.info_once(
"Using FlashInfer backend on V1 engine by default for " "Using FlashInfer backend on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.") "Blackwell (SM 10.0) GPUs.")
return ("vllm.v1.attention.backends." return FLASHINFER_V1
"flashinfer.FlashInferBackend")
except ImportError: except ImportError:
logger.info_once( logger.info_once(
"FlashInfer failed to import for V1 engine on " "FlashInfer failed to import for V1 engine on "
@ -270,10 +279,13 @@ class CudaPlatformBase(Platform):
"install FlashInfer for better performance.") "install FlashInfer for better performance.")
pass pass
# FlashAttention is the default for SM 8.0+ GPUs # FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80): if cls.has_device_capability(80) and \
supports_head_size(FLASH_ATTN_V1, head_size):
logger.info_once("Using Flash Attention backend on V1 engine.") logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return FLASH_ATTN_V1
"flash_attn.FlashAttentionBackend")
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
# Backends for V0 engine # Backends for V0 engine
if selected_backend == _Backend.FLASHINFER: if selected_backend == _Backend.FLASHINFER:

View File

@ -3,7 +3,8 @@
import numpy as np import numpy as np
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl, from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata) TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
@ -17,9 +18,24 @@ from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
class TorchSDPABackend: class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False accept_output_buffer: bool = False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return PagedAttention.get_supported_head_sizes()
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TORCH_SDPA_VLLM_V1" return "TORCH_SDPA_VLLM_V1"

View File

@ -44,10 +44,21 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @classmethod
def get_supported_head_sizes() -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_VLLM_V1" return "FLASH_ATTN_VLLM_V1"
@ -416,12 +427,7 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() FlashAttentionBackend.validate_head_size(head_size)
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "

View File

@ -38,10 +38,22 @@ class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @classmethod
def get_supported_head_sizes() -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256] return [64, 128, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASHINFER_VLLM_V1" return "FLASHINFER_VLLM_V1"
@ -207,14 +219,8 @@ class FlashInferMetadata:
return self.qo_indptr return self.qo_indptr
def __post_init__(self): def __post_init__(self):
# Refer to if self.head_dim is not None:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 FlashInferBackend.validate_head_size(self.head_dim)
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f" received {self.head_dim}.")
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -21,9 +21,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
if current_platform.is_cuda():
pass
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -45,9 +42,9 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
class FlexAttentionBackend(AttentionBackend): class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @classmethod
def get_supported_head_sizes() -> list[int]: def validate_head_size(cls, head_size: int) -> None:
return [16, 32, 64, 96, 128, 160, 192, 224, 256] return # FlexAttention supports any head size
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
@ -384,12 +381,8 @@ class FlexAttentionImpl(AttentionImpl):
raise NotImplementedError( raise NotImplementedError(
"FlexAttention does not support kv sharing yet.") "FlexAttention does not support kv sharing yet.")
support_head_sizes = FlexAttentionBackend.get_supported_head_sizes() FlexAttentionBackend.validate_head_size(head_size)
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet") "FlexAttention does not support quantized kv-cache. Yet")
@ -464,12 +457,20 @@ class FlexAttentionImpl(AttentionImpl):
# Doesn't work for now -> constraint violation # Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2) # torch._dynamo.try_mark_dynamic(query, 2)
# default M=64, N=64 may run out of shared memory on # default M=64, N=64 may run out of shared memory on some GPUs
# some GPUs with fp32, so we use smaller M and N. # TODO: Explicit configs for each GPU?
extra_kernel_options = { # Not sure how to calculate the shared memory requirement
"BLOCK_M": 32, extra_kernel_options = defaultdict[str, int](lambda: 64)
"BLOCK_N": 32 if query.dtype == torch.float32:
} if query.dtype == torch.float32 else {} extra_kernel_options["BLOCK_M"] //= 2
extra_kernel_options["BLOCK_N"] //= 2
if current_platform.is_cuda():
device_props = torch.cuda.get_device_properties()
max_shared_memory = device_props.shared_memory_per_block_optin
if max_shared_memory < 144 * 1024:
extra_kernel_options["BLOCK_M"] //= 2
extra_kernel_options["BLOCK_N"] //= 2
out = flex_attention_compiled( out = flex_attention_compiled(
query, query,
key_cache, key_cache,

View File

@ -254,10 +254,21 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@staticmethod @classmethod
def get_supported_head_sizes() -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [576] return [576]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@dataclass @dataclass
class MLACommonPrefillMetadata: class MLACommonPrefillMetadata:
@ -320,12 +331,8 @@ class MLACommonMetadata(Generic[D]):
prefill: Optional[MLACommonPrefillMetadata] = None prefill: Optional[MLACommonPrefillMetadata] = None
def __post_init__(self): def __post_init__(self):
supported_head_sizes = MLACommonBackend.get_supported_head_sizes() if self.head_dim is not None:
if self.head_dim is not None and self.head_dim \ MLACommonBackend.validate_head_size(self.head_dim)
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
M = TypeVar("M", bound=MLACommonMetadata) M = TypeVar("M", bound=MLACommonMetadata)

View File

@ -314,10 +314,21 @@ class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @classmethod
def get_supported_head_sizes() -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN_VLLM_V1" return "FLASH_ATTN_VLLM_V1"
@ -428,14 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = \ AiterFlashAttentionBackend.validate_head_size(head_size)
AiterFlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by "
"AiterFlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "

View File

@ -190,10 +190,21 @@ class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
@staticmethod @classmethod
def get_supported_head_sizes() -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256] return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_ATTN_VLLM_V1" return "TRITON_ATTN_VLLM_V1"
@ -268,11 +279,7 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes() TritonAttentionBackend.validate_head_size(head_size)
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention. "
f"Supported head sizes are: {support_head_sizes}.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "

View File

@ -12,8 +12,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
FlashAttentionMetadata) from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel