[V1] Support any head size for FlexAttention backend (#20467)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

This commit is contained in:
  parent: e202dd2736
  commit: 9fb52e523a
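In effect, models whose attention head size FlashAttention cannot handle (e.g. head_dim = 80 in microsoft/phi-2) can now run on the V1 engine by falling back to FlexAttention. A minimal usage sketch, not part of the diff, assuming a CUDA build of vLLM at this commit:

from vllm import LLM

# phi-2 uses head_dim = 80, which FlashAttention rejects; after this change
# the V1 engine falls back to the FlexAttention backend instead of refusing
# to start.
llm = LLM(model="microsoft/phi-2")
print(llm.generate("Hello, my name is")[0].outputs[0].text)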
@@ -107,10 +107,9 @@ fi
 if [[ $commands == *" kernels/attention"* ]]; then
   commands="${commands} \
-  --ignore=kernels/attention/stest_attention_selector.py \
+  --ignore=kernels/attention/test_attention_selector.py \
   --ignore=kernels/attention/test_blocksparse_attention.py \
   --ignore=kernels/attention/test_encoder_decoder_attn.py \
-  --ignore=kernels/attention/test_attention_selector.py \
   --ignore=kernels/attention/test_flash_attn.py \
   --ignore=kernels/attention/test_flashinfer.py \
   --ignore=kernels/attention/test_prefix_prefill.py \
@@ -626,9 +626,6 @@ Specified using `--task generate`.
 !!! note
     Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.

-!!! note
-    `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
-
 !!! note
     To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
@@ -671,11 +668,8 @@ Specified using `--task generate`.
     Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.

 !!! note
     To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via
     `pip install git+https://github.com/huggingface/transformers.git`.

-    Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
-    `--mm-processor-kwargs '{"use_audio_in_video": true}'`.
+    For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
+    is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.

 #### Transcription
@@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompts = [f"Question: {question} Answer:" for question in questions]
     engine_args = EngineArgs(
-        model="Salesforce/blip2-opt-6.7b",
+        model="Salesforce/blip2-opt-2.7b",
         limit_mm_per_prompt={modality: 1},
     )
@@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
     )


-# Qwen
+# Qwen-VL
 def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -172,7 +172,7 @@ def test_env(
                 expected = "FLASHINFER_VLLM_V1" if use_v1 else name
                 assert backend.get_name() == expected
         else:
-            backend = get_attn_backend(16,
+            backend = get_attn_backend(32,
                                        torch.float16,
                                        torch.float16,
                                        block_size,
@@ -181,6 +181,17 @@ def test_env(
             expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
             assert backend.get_name() == expected

+    if use_v1:
+        backend = get_attn_backend(16,
+                                   torch.float16,
+                                   torch.float16,
+                                   block_size,
+                                   False,
+                                   use_mla=use_mla)
+        assert backend.get_name() == "FLEX_ATTENTION", (
+            "Should fallback to FlexAttention if head size is "
+            "not supported by FlashAttention")
+

 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 @pytest.mark.parametrize("use_v1", [True, False])
@@ -33,9 +33,6 @@ if current_platform.is_rocm():
     os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

 REQUIRES_V0_MODELS = [
-    # V1 Test: no way to fall back for head_dim = 80
-    # https://github.com/vllm-project/vllm/issues/14524
-    "qwen_vl",
     # V1 Test: not enough KV cache space in CI.
    "fuyu",
 ]
@@ -221,8 +218,7 @@ VLM_TEST_SETTINGS = {
         marks=[large_gpu_mark(min_gb=32)],
     ),
     "blip2": VLMTestInfo(
-        # TODO: Change back to 2.7b once head_dim = 80 is supported
-        models=["Salesforce/blip2-opt-6.7b"],
+        models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
         prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
         img_idx_to_prompt=lambda idx: "",
@@ -340,8 +336,7 @@ VLM_TEST_SETTINGS = {
     "h2ovl": VLMTestInfo(
         models=[
             "h2oai/h2ovl-mississippi-800m",
-            # TODO: Re-enable once head_dim = 80 is supported
-            # "h2oai/h2ovl-mississippi-2b",
+            "h2oai/h2ovl-mississippi-2b",
         ],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501
@@ -83,7 +83,7 @@ MODELS = [
     QWEN2_CONFIG,
     PHI3_CONFIG,
     GPT2_CONFIG,
-    # STABLELM_CONFIG, # enable this when v1 support head_size=80
+    STABLELM_CONFIG,
     DOLPHIN_CONFIG,
     # STARCODER_CONFIG, # broken
 ]
@@ -240,8 +240,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                         trust_remote_code=True),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
-    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
+    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
+    # Blocksparse attention not supported in V1 yet
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                             trust_remote_code=True,
                                             v0_only=True),
@@ -258,10 +259,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
     "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
-    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
-                                                v0_only=True),
-    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
-                                           v0_only=True),
+    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
+    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
     "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -330,8 +329,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
-                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
-                                                     v0_only=True),
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@@ -359,8 +357,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                          trust_remote_code=True),
     "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
                                                       extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
-                                                      trust_remote_code=True,
-                                                      v0_only=True),
+                                                      trust_remote_code=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
                                                       max_model_len=10240),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
@@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info.check_transformers_version(on_fail="skip")

     # FIXME: Possible memory leak in the previous tests?
-    if model_arch == "GraniteSpeechForConditionalGeneration":
+    if model_arch in ("GraniteSpeechForConditionalGeneration",
+                      "KimiVLForConditionalGeneration"):
         pytest.skip("Avoid OOM")

     # Avoid OOM and reduce initialization time by only using 1 layer
@@ -310,7 +310,8 @@ class MultiHeadAttention(nn.Module):
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
-            if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
+                           _Backend.FLEX_ATTENTION):
                 backend = _Backend.XFORMERS

             self.attn_backend = backend if backend in {
@@ -4,7 +4,7 @@
 import os
 from contextlib import contextmanager
 from functools import cache
-from typing import Generator, Optional, Type
+from typing import Generator, Optional, Union

 import torch
@@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend


+def supports_head_size(
+    attn_backend: Union[str, type[AttentionBackend]],
+    head_size: int,
+) -> bool:
+    if isinstance(attn_backend, str):
+        try:
+            attn_backend = resolve_obj_by_qualname(attn_backend)
+        except ImportError:
+            return False
+
+    assert isinstance(attn_backend, type)
+
+    # TODO: Update the interface once V0 is removed
+    if get_supported_head_sizes := getattr(attn_backend,
+                                           "get_supported_head_sizes", None):
+        return head_size in get_supported_head_sizes()
+    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        try:
+            validate_head_size(head_size)
+            return True
+        except Exception:
+            return False
+
+    raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                              "head size validation")
+
+
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -87,7 +114,7 @@ def get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_mla: bool = False,
-) -> Type[AttentionBackend]:
+) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
     # value to be returned from the cache if the value changes between calls.
@@ -115,7 +142,7 @@ def _cached_get_attn_backend(
     is_blocksparse: bool = False,
     use_v1: bool = False,
     use_mla: bool = False,
-) -> Type[AttentionBackend]:
+) -> type[AttentionBackend]:
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
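A quick sketch, not in the diff, of how the new helper behaves, using the FlashAttention V1 backend's supported-size list from the hunks below:

from vllm.attention.selector import supports_head_size

FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"

# 128 is in FlashAttention's supported list; 80 (e.g. phi-2) is not, so the
# platform code can pick FlexAttention instead of raising at startup.
assert supports_head_size(FLASH_ATTN_V1, 128)
assert not supports_head_size(FLASH_ATTN_V1, 80)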
@@ -2319,7 +2319,7 @@ class SchedulerConfig:

         if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
             logger.warning(
-                "max_num_batched_tokens (%d) exceeds max_num_seqs"
+                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                 "* max_model_len (%d). This may lead to unexpected behavior.",
                 self.max_num_batched_tokens,
                 self.max_num_seqs * self.max_model_len)
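The one-character change above matters because Python joins adjacent string literals with no separator; a minimal illustration:

msg = ("max_num_batched_tokens (%d) exceeds max_num_seqs"
       "* max_model_len (%d). This may lead to unexpected behavior.")
# Without a trailing space on the first literal, the rendered warning reads
# "...exceeds max_num_seqs* max_model_len...", i.e. the space is lost.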
@@ -234,35 +234,44 @@ class CudaPlatformBase(Platform):
                 return ("vllm.attention.backends."
                         "flashmla.FlashMLABackend")
         if use_v1:
+            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+            TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
-                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
+                return FLASHINFER_V1
             elif selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info("Using FlexAttenion backend on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+                logger.info_once("Using FlexAttention backend on V1 engine.")
+                return FLEX_ATTENTION_V1
             elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "triton_attn.TritonAttentionBackend")
+                return TRITON_ATTN_VLLM_V1
             elif selected_backend == _Backend.FLASH_ATTN:
                 logger.info_once("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "flash_attn.FlashAttentionBackend")
+                return FLASH_ATTN_V1

+            from vllm.attention.selector import supports_head_size
+
             # Default backends for V1 engine
-            # Prefer FlashInfer for Blackwell GPUs if installed
             # FP32 is only supported by FlexAttention
             if dtype not in (torch.float16, torch.bfloat16):
                 logger.info_once(
-                    f"Using FlexAttenion backend for {dtype} on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            if cls.is_device_capability(100):
+                    "Using FlexAttention backend for %s on V1 engine.",
+                    dtype,
+                )
+                return FLEX_ATTENTION_V1
+
+            # Prefer FlashInfer for Blackwell GPUs if installed
+            if cls.is_device_capability(100) and \
+                supports_head_size(FLASHINFER_V1, head_size):
                 try:
                     import flashinfer  # noqa: F401
                     logger.info_once(
                         "Using FlashInfer backend on V1 engine by default for "
                         "Blackwell (SM 10.0) GPUs.")
-                    return ("vllm.v1.attention.backends."
-                            "flashinfer.FlashInferBackend")
+                    return FLASHINFER_V1
                 except ImportError:
                     logger.info_once(
                         "FlashInfer failed to import for V1 engine on "
@@ -270,10 +279,13 @@ class CudaPlatformBase(Platform):
                         "install FlashInfer for better performance.")
                     pass
             # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80):
+            if cls.has_device_capability(80) and \
+                supports_head_size(FLASH_ATTN_V1, head_size):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "flash_attn.FlashAttentionBackend")
+                return FLASH_ATTN_V1
+
+            logger.info_once("Using FlexAttention backend on V1 engine.")
+            return FLEX_ATTENTION_V1

         # Backends for V0 engine
         if selected_backend == _Backend.FLASHINFER:
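Distilled, the V1 default-selection order after this change is: any dtype other than fp16/bf16 goes to FlexAttention; Blackwell (SM 10.0) prefers FlashInfer when it is installed and supports the head size; SM 8.0+ uses FlashAttention when it supports the head size; everything else falls back to FlexAttention. A standalone sketch of that decision ladder (illustrative names, not the actual method), with the supported-size lists taken from the backend hunks below:

import torch

def pick_v1_backend(dtype: torch.dtype, head_size: int,
                    capability: int, has_flashinfer: bool) -> str:
    # Mirrors the control flow above; FlexAttention is the universal fallback.
    if dtype not in (torch.float16, torch.bfloat16):
        return "FLEX_ATTENTION"   # e.g. fp32 only works on FlexAttention
    if (capability == 100 and has_flashinfer
            and head_size in (64, 128, 256)):
        return "FLASHINFER"       # Blackwell default when usable
    if (capability >= 80
            and head_size in (32, 64, 96, 128, 160, 192, 224, 256)):
        return "FLASH_ATTN"
    return "FLEX_ATTENTION"       # any head size, any SM 8.0+ GPU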
@@ -3,7 +3,8 @@
 import numpy as np
 import torch

-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata)
 from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
                                                 TorchSDPAMetadata)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -17,9 +18,24 @@ from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_input_batch import InputBatch


-class TorchSDPABackend:
+class TorchSDPABackend(AttentionBackend):
     accept_output_buffer: bool = False

+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
+        return PagedAttention.get_supported_head_sizes()
+
+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes.")
+
     @staticmethod
     def get_name() -> str:
         return "TORCH_SDPA_VLLM_V1"
@@ -44,10 +44,21 @@ class FlashAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True

-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes.")
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN_VLLM_V1"
@@ -416,12 +427,7 @@ class FlashAttentionImpl(AttentionImpl):

         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {support_head_sizes}. "
-                "Set VLLM_USE_V1=0 to use another attention backend.")
+        FlashAttentionBackend.validate_head_size(head_size)

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
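A hypothetical call, not in the diff, showing the new error message, which now points users at FlexAttention rather than at falling back to V0:

from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

try:
    FlashAttentionBackend.validate_head_size(80)
except ValueError as exc:
    # "Head size 80 is not supported by FlashAttention. Supported head
    #  sizes are: [32, 64, 96, 128, 160, 192, 224, 256]. Set
    #  VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use FlexAttention backend
    #  which supports all head sizes."
    print(exc)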
@@ -38,10 +38,22 @@ class FlashInferBackend(AttentionBackend):

     accept_output_buffer: bool = True

-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
+        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
         return [64, 128, 256]

+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes.")
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER_VLLM_V1"
@@ -207,14 +219,8 @@ class FlashInferMetadata:
         return self.qo_indptr

     def __post_init__(self):
-        # Refer to
-        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
-        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
-        if self.head_dim is not None and self.head_dim \
-            not in supported_head_sizes:
-            raise ValueError(
-                f"Only {supported_head_sizes} are supported for head_dim,",
-                f" received {self.head_dim}.")
+        if self.head_dim is not None:
+            FlashInferBackend.validate_head_size(self.head_dim)


 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlashAttention."""

+from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

@@ -21,9 +21,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

-if current_platform.is_cuda():
-    pass
-
 logger = init_logger(__name__)

 if TYPE_CHECKING:
@@ -45,9 +42,9 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
 class FlexAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True

-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
-        return [16, 32, 64, 96, 128, 160, 192, 224, 256]
+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        return  # FlexAttention supports any head size

     @staticmethod
     def get_name() -> str:
@@ -384,12 +381,8 @@ class FlexAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlexAttention does not support kv sharing yet.")

-        support_head_sizes = FlexAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {support_head_sizes}. "
-                "Set VLLM_USE_V1=0 to use another attention backend.")
+        FlexAttentionBackend.validate_head_size(head_size)

         if is_quantized_kv_cache(self.kv_cache_dtype):
             raise NotImplementedError(
                 "FlexAttention does not support quantized kv-cache. Yet")
@@ -464,12 +457,20 @@ class FlexAttentionImpl(AttentionImpl):
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)

-        # default M=64, N=64 may run out of shared memory on
-        # some GPUs with fp32, so we use smaller M and N.
-        extra_kernel_options = {
-            "BLOCK_M": 32,
-            "BLOCK_N": 32
-        } if query.dtype == torch.float32 else {}
+        # default M=64, N=64 may run out of shared memory on some GPUs
+        # TODO: Explicit configs for each GPU?
+        # Not sure how to calculate the shared memory requirement
+        extra_kernel_options = defaultdict[str, int](lambda: 64)
+        if query.dtype == torch.float32:
+            extra_kernel_options["BLOCK_M"] //= 2
+            extra_kernel_options["BLOCK_N"] //= 2
+        if current_platform.is_cuda():
+            device_props = torch.cuda.get_device_properties()
+            max_shared_memory = device_props.shared_memory_per_block_optin
+            if max_shared_memory < 144 * 1024:
+                extra_kernel_options["BLOCK_M"] //= 2
+                extra_kernel_options["BLOCK_N"] //= 2

         out = flex_attention_compiled(
             query,
             key_cache,
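The defaultdict above lets every Triton kernel option default to 64, while BLOCK_M/BLOCK_N are halved once for fp32 and once more on GPUs with under 144 KiB of opt-in shared memory per block. A self-contained sketch of the arithmetic (is_fp32 and small_smem stand in for the dtype and device checks):

from collections import defaultdict

opts = defaultdict[str, int](lambda: 64)
is_fp32, small_smem = True, True  # worst case: fp32 on a low-smem GPU
if is_fp32:
    opts["BLOCK_M"] //= 2
    opts["BLOCK_N"] //= 2
if small_smem:
    opts["BLOCK_M"] //= 2
    opts["BLOCK_N"] //= 2
assert opts["BLOCK_M"] == opts["BLOCK_N"] == 16
assert opts["BLOCK_K"] == 64  # untouched keys keep the 64 default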
@@ -254,10 +254,21 @@ class MLACommonBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)

-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
         return [576]

+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes.")
+

 @dataclass
 class MLACommonPrefillMetadata:
@@ -320,12 +331,8 @@ class MLACommonMetadata(Generic[D]):
     prefill: Optional[MLACommonPrefillMetadata] = None

     def __post_init__(self):
-        supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
-        if self.head_dim is not None and self.head_dim \
-            not in supported_head_sizes:
-            raise ValueError(
-                f"Only {supported_head_sizes} are supported for head_dim,",
-                f"received {self.head_dim}.")
+        if self.head_dim is not None:
+            MLACommonBackend.validate_head_size(self.head_dim)


 M = TypeVar("M", bound=MLACommonMetadata)
@@ -314,10 +314,21 @@ class AiterFlashAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True

-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes.")
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN_VLLM_V1"
@@ -428,14 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        support_head_sizes = \
-            AiterFlashAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by "
-                "AiterFlashAttention. "
-                f"Supported head sizes are: {support_head_sizes}. "
-                "Set VLLM_USE_V1=0 to use another attention backend.")
+        AiterFlashAttentionBackend.validate_head_size(head_size)

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
@@ -190,10 +190,21 @@ class TritonAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True

-    @staticmethod
-    def get_supported_head_sizes() -> list[int]:
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @classmethod
+    def validate_head_size(cls, head_size: int) -> None:
+        supported_head_sizes = cls.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            attn_type = cls.__name__.removesuffix("Backend")
+            raise ValueError(
+                f"Head size {head_size} is not supported by {attn_type}. "
+                f"Supported head sizes are: {supported_head_sizes}. "
+                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "FlexAttention backend which supports all head sizes.")
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN_VLLM_V1"
@@ -268,11 +279,7 @@ class TritonAttentionImpl(AttentionImpl):

         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
-        if head_size not in support_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by TritonAttention. "
-                f"Supported head sizes are: {support_head_sizes}.")
+        TritonAttentionBackend.validate_head_size(head_size)

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
@@ -12,8 +12,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
-                                                   FlashAttentionMetadata)
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel