[V1] Support any head size for FlexAttention backend (#20467)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-07-07 00:54:36 +08:00 committed by GitHub
parent e202dd2736
commit 9fb52e523a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 202 additions and 118 deletions

View File

@ -107,10 +107,9 @@ fi
if [[ $commands == *" kernels/attention"* ]]; then
commands="${commands} \
--ignore=kernels/attention/stest_attention_selector.py \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/attention/test_flashinfer.py \
--ignore=kernels/attention/test_prefix_prefill.py \

View File

@ -626,9 +626,6 @@ Specified using `--task generate`.
!!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
!!! note
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
!!! note
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
@ -671,11 +668,8 @@ Specified using `--task generate`.
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
!!! note
To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers.git`.
Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
`--mm-processor-kwargs '{"use_audio_in_video": true}'`.
For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
#### Transcription

View File

@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs(
model="Salesforce/blip2-opt-6.7b",
model="Salesforce/blip2-opt-2.7b",
limit_mm_per_prompt={modality: 1},
)
@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
)
# Qwen
# Qwen-VL
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"

View File

@ -172,7 +172,7 @@ def test_env(
expected = "FLASHINFER_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
backend = get_attn_backend(32,
torch.float16,
torch.float16,
block_size,
@ -181,6 +181,17 @@ def test_env(
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
if use_v1:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == "FLEX_ATTENTION", (
"Should fallback to FlexAttention if head size is "
"not supported by FlashAttention")
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])

View File

@ -33,9 +33,6 @@ if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
REQUIRES_V0_MODELS = [
# V1 Test: no way to fall back for head_dim = 80
# https://github.com/vllm-project/vllm/issues/14524
"qwen_vl",
# V1 Test: not enough KV cache space in C1.
"fuyu",
]
@ -221,8 +218,7 @@ VLM_TEST_SETTINGS = {
marks=[large_gpu_mark(min_gb=32)],
),
"blip2": VLMTestInfo(
# TODO: Change back to 2.7b once head_dim = 80 is supported
models=["Salesforce/blip2-opt-6.7b"],
models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
img_idx_to_prompt=lambda idx: "",
@ -340,8 +336,7 @@ VLM_TEST_SETTINGS = {
"h2ovl": VLMTestInfo(
models = [
"h2oai/h2ovl-mississippi-800m",
# TODO: Re-enable once head_dim = 80 is supported
# "h2oai/h2ovl-mississippi-2b",
"h2oai/h2ovl-mississippi-2b",
],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501

View File

@ -83,7 +83,7 @@ MODELS = [
QWEN2_CONFIG,
PHI3_CONFIG,
GPT2_CONFIG,
# STABLELM_CONFIG, # enable this when v1 support head_size=80
STABLELM_CONFIG,
DOLPHIN_CONFIG,
# STARCODER_CONFIG, # broken
]

View File

@ -240,8 +240,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
# Blocksparse attention not supported in V1 yet
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
@ -258,10 +259,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
v0_only=True),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
v0_only=True),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@ -330,8 +329,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
v0_only=True),
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@ -359,8 +357,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True,
v0_only=True),
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
max_model_len=10240),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",

View File

@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
model_info.check_transformers_version(on_fail="skip")
# FIXME: Possible memory leak in the previous tests?
if model_arch == "GraniteSpeechForConditionalGeneration":
if model_arch in ("GraniteSpeechForConditionalGeneration",
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM")
# Avoid OOM and reduce initialization time by only using 1 layer

View File

@ -310,7 +310,8 @@ class MultiHeadAttention(nn.Module):
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS
self.attn_backend = backend if backend in {

View File

@ -4,7 +4,7 @@
import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Type
from typing import Generator, Optional, Union
import torch
@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend
def supports_head_size(
attn_backend: Union[str, type[AttentionBackend]],
head_size: int,
) -> bool:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
return False
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(attn_backend,
"get_supported_head_sizes", None):
return head_size in get_supported_head_sizes()
if validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
return True
except Exception:
return False
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@ -87,7 +114,7 @@ def get_attn_backend(
is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
@ -115,7 +142,7 @@ def _cached_get_attn_backend(
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
) -> type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (

View File

@ -2319,7 +2319,7 @@ class SchedulerConfig:
if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs"
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
self.max_num_seqs * self.max_model_len)

View File

@ -234,35 +234,44 @@ class CudaPlatformBase(Platform):
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttenion backend on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
return TRITON_ATTN_VLLM_V1
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
return FLASH_ATTN_V1
from vllm.attention.selector import supports_head_size
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
# FP32 is only supported by FlexAttention
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100):
"Using FlexAttention backend for %s on V1 engine.",
dtype,
)
return FLEX_ATTENTION_V1
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100) and \
supports_head_size(FLASHINFER_V1, head_size):
try:
import flashinfer # noqa: F401
logger.info_once(
"Using FlashInfer backend on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
return ("vllm.v1.attention.backends."
"flashinfer.FlashInferBackend")
return FLASHINFER_V1
except ImportError:
logger.info_once(
"FlashInfer failed to import for V1 engine on "
@ -270,10 +279,13 @@ class CudaPlatformBase(Platform):
"install FlashInfer for better performance.")
pass
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if cls.has_device_capability(80) and \
supports_head_size(FLASH_ATTN_V1, head_size):
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
return FLASH_ATTN_V1
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
# Backends for V0 engine
if selected_backend == _Backend.FLASHINFER:

View File

@ -3,7 +3,8 @@
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
@ -17,9 +18,24 @@ from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
class TorchSDPABackend:
class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return PagedAttention.get_supported_head_sizes()
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"

View File

@ -44,10 +44,21 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_VLLM_V1"
@ -416,12 +427,7 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
FlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "

View File

@ -38,10 +38,22 @@ class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
return [64, 128, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "FLASHINFER_VLLM_V1"
@ -207,14 +219,8 @@ class FlashInferMetadata:
return self.qo_indptr
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f" received {self.head_dim}.")
if self.head_dim is not None:
FlashInferBackend.validate_head_size(self.head_dim)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
@ -21,9 +21,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if current_platform.is_cuda():
pass
logger = init_logger(__name__)
if TYPE_CHECKING:
@ -45,9 +42,9 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [16, 32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
return # FlexAttention supports any head size
@staticmethod
def get_name() -> str:
@ -384,12 +381,8 @@ class FlexAttentionImpl(AttentionImpl):
raise NotImplementedError(
"FlexAttention does not support kv sharing yet.")
support_head_sizes = FlexAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
FlexAttentionBackend.validate_head_size(head_size)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet")
@ -464,12 +457,20 @@ class FlexAttentionImpl(AttentionImpl):
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
# default M=64, N=64 may run out of shared memory on
# some GPUs with fp32, so we use smaller M and N.
extra_kernel_options = {
"BLOCK_M": 32,
"BLOCK_N": 32
} if query.dtype == torch.float32 else {}
# default M=64, N=64 may run out of shared memory on some GPUs
# TODO: Explicit configs for each GPU?
# Not sure how to calculate the shared memory requirement
extra_kernel_options = defaultdict[str, int](lambda: 64)
if query.dtype == torch.float32:
extra_kernel_options["BLOCK_M"] //= 2
extra_kernel_options["BLOCK_N"] //= 2
if current_platform.is_cuda():
device_props = torch.cuda.get_device_properties()
max_shared_memory = device_props.shared_memory_per_block_optin
if max_shared_memory < 144 * 1024:
extra_kernel_options["BLOCK_M"] //= 2
extra_kernel_options["BLOCK_N"] //= 2
out = flex_attention_compiled(
query,
key_cache,

View File

@ -254,10 +254,21 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@dataclass
class MLACommonPrefillMetadata:
@ -320,12 +331,8 @@ class MLACommonMetadata(Generic[D]):
prefill: Optional[MLACommonPrefillMetadata] = None
def __post_init__(self):
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
if self.head_dim is not None:
MLACommonBackend.validate_head_size(self.head_dim)
M = TypeVar("M", bound=MLACommonMetadata)

View File

@ -314,10 +314,21 @@ class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_VLLM_V1"
@ -428,14 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = \
AiterFlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by "
"AiterFlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
AiterFlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "

View File

@ -190,10 +190,21 @@ class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
@staticmethod
def get_name() -> str:
return "TRITON_ATTN_VLLM_V1"
@ -268,11 +279,7 @@ class TritonAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention. "
f"Supported head sizes are: {support_head_sizes}.")
TritonAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "

View File

@ -12,8 +12,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel