mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 21:37:58 +08:00
[V1] Support any head size for FlexAttention backend (#20467)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
e202dd2736
commit
9fb52e523a
@ -107,10 +107,9 @@ fi
|
|||||||
|
|
||||||
if [[ $commands == *" kernels/attention"* ]]; then
|
if [[ $commands == *" kernels/attention"* ]]; then
|
||||||
commands="${commands} \
|
commands="${commands} \
|
||||||
--ignore=kernels/attention/stest_attention_selector.py \
|
--ignore=kernels/attention/test_attention_selector.py \
|
||||||
--ignore=kernels/attention/test_blocksparse_attention.py \
|
--ignore=kernels/attention/test_blocksparse_attention.py \
|
||||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||||
--ignore=kernels/attention/test_attention_selector.py \
|
|
||||||
--ignore=kernels/attention/test_flash_attn.py \
|
--ignore=kernels/attention/test_flash_attn.py \
|
||||||
--ignore=kernels/attention/test_flashinfer.py \
|
--ignore=kernels/attention/test_flashinfer.py \
|
||||||
--ignore=kernels/attention/test_prefix_prefill.py \
|
--ignore=kernels/attention/test_prefix_prefill.py \
|
||||||
|
|||||||
@ -626,9 +626,6 @@ Specified using `--task generate`.
|
|||||||
!!! note
|
!!! note
|
||||||
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
|
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
|
||||||
|
|
||||||
!!! note
|
|
||||||
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
|
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
||||||
|
|
||||||
@ -671,11 +668,8 @@ Specified using `--task generate`.
|
|||||||
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
|
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via
|
For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
|
||||||
`pip install git+https://github.com/huggingface/transformers.git`.
|
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
|
||||||
|
|
||||||
Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
|
|
||||||
`--mm-processor-kwargs '{"use_audio_in_video": true}'`.
|
|
||||||
|
|
||||||
#### Transcription
|
#### Transcription
|
||||||
|
|
||||||
|
|||||||
@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
|
|||||||
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
|
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
|
||||||
prompts = [f"Question: {question} Answer:" for question in questions]
|
prompts = [f"Question: {question} Answer:" for question in questions]
|
||||||
engine_args = EngineArgs(
|
engine_args = EngineArgs(
|
||||||
model="Salesforce/blip2-opt-6.7b",
|
model="Salesforce/blip2-opt-2.7b",
|
||||||
limit_mm_per_prompt={modality: 1},
|
limit_mm_per_prompt={modality: 1},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Qwen
|
# Qwen-VL
|
||||||
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
|
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||||
assert modality == "image"
|
assert modality == "image"
|
||||||
|
|
||||||
|
|||||||
@ -172,7 +172,7 @@ def test_env(
|
|||||||
expected = "FLASHINFER_VLLM_V1" if use_v1 else name
|
expected = "FLASHINFER_VLLM_V1" if use_v1 else name
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
else:
|
else:
|
||||||
backend = get_attn_backend(16,
|
backend = get_attn_backend(32,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.float16,
|
torch.float16,
|
||||||
block_size,
|
block_size,
|
||||||
@ -181,6 +181,17 @@ def test_env(
|
|||||||
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
|
|
||||||
|
if use_v1:
|
||||||
|
backend = get_attn_backend(16,
|
||||||
|
torch.float16,
|
||||||
|
torch.float16,
|
||||||
|
block_size,
|
||||||
|
False,
|
||||||
|
use_mla=use_mla)
|
||||||
|
assert backend.get_name() == "FLEX_ATTENTION", (
|
||||||
|
"Should fallback to FlexAttention if head size is "
|
||||||
|
"not supported by FlashAttention")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||||
@pytest.mark.parametrize("use_v1", [True, False])
|
@pytest.mark.parametrize("use_v1", [True, False])
|
||||||
|
|||||||
@ -33,9 +33,6 @@ if current_platform.is_rocm():
|
|||||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||||
|
|
||||||
REQUIRES_V0_MODELS = [
|
REQUIRES_V0_MODELS = [
|
||||||
# V1 Test: no way to fall back for head_dim = 80
|
|
||||||
# https://github.com/vllm-project/vllm/issues/14524
|
|
||||||
"qwen_vl",
|
|
||||||
# V1 Test: not enough KV cache space in C1.
|
# V1 Test: not enough KV cache space in C1.
|
||||||
"fuyu",
|
"fuyu",
|
||||||
]
|
]
|
||||||
@ -221,8 +218,7 @@ VLM_TEST_SETTINGS = {
|
|||||||
marks=[large_gpu_mark(min_gb=32)],
|
marks=[large_gpu_mark(min_gb=32)],
|
||||||
),
|
),
|
||||||
"blip2": VLMTestInfo(
|
"blip2": VLMTestInfo(
|
||||||
# TODO: Change back to 2.7b once head_dim = 80 is supported
|
models=["Salesforce/blip2-opt-2.7b"],
|
||||||
models=["Salesforce/blip2-opt-6.7b"],
|
|
||||||
test_type=VLMTestType.IMAGE,
|
test_type=VLMTestType.IMAGE,
|
||||||
prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
|
prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
|
||||||
img_idx_to_prompt=lambda idx: "",
|
img_idx_to_prompt=lambda idx: "",
|
||||||
@ -340,8 +336,7 @@ VLM_TEST_SETTINGS = {
|
|||||||
"h2ovl": VLMTestInfo(
|
"h2ovl": VLMTestInfo(
|
||||||
models = [
|
models = [
|
||||||
"h2oai/h2ovl-mississippi-800m",
|
"h2oai/h2ovl-mississippi-800m",
|
||||||
# TODO: Re-enable once head_dim = 80 is supported
|
"h2oai/h2ovl-mississippi-2b",
|
||||||
# "h2oai/h2ovl-mississippi-2b",
|
|
||||||
],
|
],
|
||||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||||
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501
|
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501
|
||||||
|
|||||||
@ -83,7 +83,7 @@ MODELS = [
|
|||||||
QWEN2_CONFIG,
|
QWEN2_CONFIG,
|
||||||
PHI3_CONFIG,
|
PHI3_CONFIG,
|
||||||
GPT2_CONFIG,
|
GPT2_CONFIG,
|
||||||
# STABLELM_CONFIG, # enable this when v1 support head_size=80
|
STABLELM_CONFIG,
|
||||||
DOLPHIN_CONFIG,
|
DOLPHIN_CONFIG,
|
||||||
# STARCODER_CONFIG, # broken
|
# STARCODER_CONFIG, # broken
|
||||||
]
|
]
|
||||||
|
|||||||
@ -240,8 +240,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
|
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
|
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
|
||||||
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
|
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
|
||||||
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
|
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
|
||||||
|
# Blocksparse attention not supported in V1 yet
|
||||||
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
|
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
v0_only=True),
|
v0_only=True),
|
||||||
@ -258,10 +259,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||||
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
|
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
|
||||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
|
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
||||||
v0_only=True),
|
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
|
|
||||||
v0_only=True),
|
|
||||||
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
|
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
|
||||||
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
|
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
|
||||||
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
||||||
@ -330,8 +329,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
|
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
|
||||||
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
|
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
|
||||||
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
|
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
|
||||||
extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
|
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
|
||||||
v0_only=True),
|
|
||||||
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
|
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
|
||||||
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
|
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
|
||||||
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
|
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
|
||||||
@ -359,8 +357,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
|
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
|
||||||
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
|
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
|
||||||
trust_remote_code=True,
|
trust_remote_code=True),
|
||||||
v0_only=True),
|
|
||||||
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
|
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
|
||||||
max_model_len=10240),
|
max_model_len=10240),
|
||||||
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
|
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
|
||||||
|
|||||||
@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
|||||||
model_info.check_transformers_version(on_fail="skip")
|
model_info.check_transformers_version(on_fail="skip")
|
||||||
|
|
||||||
# FIXME: Possible memory leak in the previous tests?
|
# FIXME: Possible memory leak in the previous tests?
|
||||||
if model_arch == "GraniteSpeechForConditionalGeneration":
|
if model_arch in ("GraniteSpeechForConditionalGeneration",
|
||||||
|
"KimiVLForConditionalGeneration"):
|
||||||
pytest.skip("Avoid OOM")
|
pytest.skip("Avoid OOM")
|
||||||
|
|
||||||
# Avoid OOM and reduce initialization time by only using 1 layer
|
# Avoid OOM and reduce initialization time by only using 1 layer
|
||||||
|
|||||||
@ -310,7 +310,8 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# currently, only torch_sdpa is supported on rocm
|
# currently, only torch_sdpa is supported on rocm
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
self.attn_backend = _Backend.TORCH_SDPA
|
||||||
else:
|
else:
|
||||||
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
|
||||||
|
_Backend.FLEX_ATTENTION):
|
||||||
backend = _Backend.XFORMERS
|
backend = _Backend.XFORMERS
|
||||||
|
|
||||||
self.attn_backend = backend if backend in {
|
self.attn_backend = backend if backend in {
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import Generator, Optional, Type
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
|
|||||||
return forced_attn_backend
|
return forced_attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
def supports_head_size(
|
||||||
|
attn_backend: Union[str, type[AttentionBackend]],
|
||||||
|
head_size: int,
|
||||||
|
) -> bool:
|
||||||
|
if isinstance(attn_backend, str):
|
||||||
|
try:
|
||||||
|
attn_backend = resolve_obj_by_qualname(attn_backend)
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert isinstance(attn_backend, type)
|
||||||
|
|
||||||
|
# TODO: Update the interface once V0 is removed
|
||||||
|
if get_supported_head_sizes := getattr(attn_backend,
|
||||||
|
"get_supported_head_sizes", None):
|
||||||
|
return head_size in get_supported_head_sizes()
|
||||||
|
if validate_head_size := getattr(attn_backend, "validate_head_size", None):
|
||||||
|
try:
|
||||||
|
validate_head_size(head_size)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
raise NotImplementedError(f"{attn_backend.__name__} does not support "
|
||||||
|
"head size validation")
|
||||||
|
|
||||||
|
|
||||||
def get_attn_backend(
|
def get_attn_backend(
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
@ -87,7 +114,7 @@ def get_attn_backend(
|
|||||||
is_attention_free: bool,
|
is_attention_free: bool,
|
||||||
is_blocksparse: bool = False,
|
is_blocksparse: bool = False,
|
||||||
use_mla: bool = False,
|
use_mla: bool = False,
|
||||||
) -> Type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
"""Selects which attention backend to use and lazily imports it."""
|
"""Selects which attention backend to use and lazily imports it."""
|
||||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||||
# value to be returned from the cache if the value changes between calls.
|
# value to be returned from the cache if the value changes between calls.
|
||||||
@ -115,7 +142,7 @@ def _cached_get_attn_backend(
|
|||||||
is_blocksparse: bool = False,
|
is_blocksparse: bool = False,
|
||||||
use_v1: bool = False,
|
use_v1: bool = False,
|
||||||
use_mla: bool = False,
|
use_mla: bool = False,
|
||||||
) -> Type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
if is_blocksparse:
|
if is_blocksparse:
|
||||||
logger.info("Using BlocksparseFlashAttention backend.")
|
logger.info("Using BlocksparseFlashAttention backend.")
|
||||||
from vllm.attention.backends.blocksparse_attn import (
|
from vllm.attention.backends.blocksparse_attn import (
|
||||||
|
|||||||
@ -2319,7 +2319,7 @@ class SchedulerConfig:
|
|||||||
|
|
||||||
if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
|
if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"max_num_batched_tokens (%d) exceeds max_num_seqs"
|
"max_num_batched_tokens (%d) exceeds max_num_seqs "
|
||||||
"* max_model_len (%d). This may lead to unexpected behavior.",
|
"* max_model_len (%d). This may lead to unexpected behavior.",
|
||||||
self.max_num_batched_tokens,
|
self.max_num_batched_tokens,
|
||||||
self.max_num_seqs * self.max_model_len)
|
self.max_num_seqs * self.max_model_len)
|
||||||
|
|||||||
@ -234,35 +234,44 @@ class CudaPlatformBase(Platform):
|
|||||||
return ("vllm.attention.backends."
|
return ("vllm.attention.backends."
|
||||||
"flashmla.FlashMLABackend")
|
"flashmla.FlashMLABackend")
|
||||||
if use_v1:
|
if use_v1:
|
||||||
|
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
|
||||||
|
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||||
|
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||||
|
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||||
|
|
||||||
if selected_backend == _Backend.FLASHINFER:
|
if selected_backend == _Backend.FLASHINFER:
|
||||||
logger.info_once("Using FlashInfer backend on V1 engine.")
|
logger.info_once("Using FlashInfer backend on V1 engine.")
|
||||||
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
|
return FLASHINFER_V1
|
||||||
elif selected_backend == _Backend.FLEX_ATTENTION:
|
elif selected_backend == _Backend.FLEX_ATTENTION:
|
||||||
logger.info("Using FlexAttenion backend on V1 engine.")
|
logger.info_once("Using FlexAttention backend on V1 engine.")
|
||||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
return FLEX_ATTENTION_V1
|
||||||
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
|
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
|
||||||
logger.info_once("Using Triton backend on V1 engine.")
|
logger.info_once("Using Triton backend on V1 engine.")
|
||||||
return ("vllm.v1.attention.backends."
|
return TRITON_ATTN_VLLM_V1
|
||||||
"triton_attn.TritonAttentionBackend")
|
|
||||||
elif selected_backend == _Backend.FLASH_ATTN:
|
elif selected_backend == _Backend.FLASH_ATTN:
|
||||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||||
return ("vllm.v1.attention.backends."
|
return FLASH_ATTN_V1
|
||||||
"flash_attn.FlashAttentionBackend")
|
|
||||||
|
from vllm.attention.selector import supports_head_size
|
||||||
|
|
||||||
# Default backends for V1 engine
|
# Default backends for V1 engine
|
||||||
# Prefer FlashInfer for Blackwell GPUs if installed
|
# FP32 is only supported by FlexAttention
|
||||||
if dtype not in (torch.float16, torch.bfloat16):
|
if dtype not in (torch.float16, torch.bfloat16):
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
f"Using FlexAttenion backend for {dtype} on V1 engine.")
|
"Using FlexAttention backend for %s on V1 engine.",
|
||||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
dtype,
|
||||||
if cls.is_device_capability(100):
|
)
|
||||||
|
return FLEX_ATTENTION_V1
|
||||||
|
|
||||||
|
# Prefer FlashInfer for Blackwell GPUs if installed
|
||||||
|
if cls.is_device_capability(100) and \
|
||||||
|
supports_head_size(FLASHINFER_V1, head_size):
|
||||||
try:
|
try:
|
||||||
import flashinfer # noqa: F401
|
import flashinfer # noqa: F401
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Using FlashInfer backend on V1 engine by default for "
|
"Using FlashInfer backend on V1 engine by default for "
|
||||||
"Blackwell (SM 10.0) GPUs.")
|
"Blackwell (SM 10.0) GPUs.")
|
||||||
return ("vllm.v1.attention.backends."
|
return FLASHINFER_V1
|
||||||
"flashinfer.FlashInferBackend")
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"FlashInfer failed to import for V1 engine on "
|
"FlashInfer failed to import for V1 engine on "
|
||||||
@ -270,10 +279,13 @@ class CudaPlatformBase(Platform):
|
|||||||
"install FlashInfer for better performance.")
|
"install FlashInfer for better performance.")
|
||||||
pass
|
pass
|
||||||
# FlashAttention is the default for SM 8.0+ GPUs
|
# FlashAttention is the default for SM 8.0+ GPUs
|
||||||
if cls.has_device_capability(80):
|
if cls.has_device_capability(80) and \
|
||||||
|
supports_head_size(FLASH_ATTN_V1, head_size):
|
||||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||||
return ("vllm.v1.attention.backends."
|
return FLASH_ATTN_V1
|
||||||
"flash_attn.FlashAttentionBackend")
|
|
||||||
|
logger.info_once("Using FlexAttention backend on V1 engine.")
|
||||||
|
return FLEX_ATTENTION_V1
|
||||||
|
|
||||||
# Backends for V0 engine
|
# Backends for V0 engine
|
||||||
if selected_backend == _Backend.FLASHINFER:
|
if selected_backend == _Backend.FLASHINFER:
|
||||||
|
|||||||
@ -3,7 +3,8 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata)
|
||||||
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
|
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
|
||||||
TorchSDPAMetadata)
|
TorchSDPAMetadata)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
@ -17,9 +18,24 @@ from vllm.v1.worker.cpu_model_runner import CPUModelRunner
|
|||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
|
|
||||||
|
|
||||||
class TorchSDPABackend:
|
class TorchSDPABackend(AttentionBackend):
|
||||||
accept_output_buffer: bool = False
|
accept_output_buffer: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
|
return PagedAttention.get_supported_head_sizes()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
|
supported_head_sizes = cls.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
attn_type = cls.__name__.removesuffix("Backend")
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by {attn_type}. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}. "
|
||||||
|
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||||
|
"FlexAttention backend which supports all head sizes.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "TORCH_SDPA_VLLM_V1"
|
return "TORCH_SDPA_VLLM_V1"
|
||||||
|
|||||||
@ -44,10 +44,21 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_supported_head_sizes() -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
|
supported_head_sizes = cls.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
attn_type = cls.__name__.removesuffix("Backend")
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by {attn_type}. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}. "
|
||||||
|
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||||
|
"FlexAttention backend which supports all head sizes.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "FLASH_ATTN_VLLM_V1"
|
return "FLASH_ATTN_VLLM_V1"
|
||||||
@ -416,12 +427,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
FlashAttentionBackend.validate_head_size(head_size)
|
||||||
if head_size not in support_head_sizes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Head size {head_size} is not supported by FlashAttention. "
|
|
||||||
f"Supported head sizes are: {support_head_sizes}. "
|
|
||||||
"Set VLLM_USE_V1=0 to use another attention backend.")
|
|
||||||
|
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
|||||||
@ -38,10 +38,22 @@ class FlashInferBackend(AttentionBackend):
|
|||||||
|
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_supported_head_sizes() -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
|
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
||||||
return [64, 128, 256]
|
return [64, 128, 256]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
|
supported_head_sizes = cls.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
attn_type = cls.__name__.removesuffix("Backend")
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by {attn_type}. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}. "
|
||||||
|
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||||
|
"FlexAttention backend which supports all head sizes.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "FLASHINFER_VLLM_V1"
|
return "FLASHINFER_VLLM_V1"
|
||||||
@ -207,14 +219,8 @@ class FlashInferMetadata:
|
|||||||
return self.qo_indptr
|
return self.qo_indptr
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Refer to
|
if self.head_dim is not None:
|
||||||
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
FlashInferBackend.validate_head_size(self.head_dim)
|
||||||
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
|
|
||||||
if self.head_dim is not None and self.head_dim \
|
|
||||||
not in supported_head_sizes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
|
||||||
f" received {self.head_dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Attention layer with FlashAttention."""
|
"""Attention layer with FlashAttention."""
|
||||||
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
@ -21,9 +21,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -45,9 +42,9 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
|
|||||||
class FlexAttentionBackend(AttentionBackend):
|
class FlexAttentionBackend(AttentionBackend):
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_supported_head_sizes() -> list[int]:
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
return [16, 32, 64, 96, 128, 160, 192, 224, 256]
|
return # FlexAttention supports any head size
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
@ -384,12 +381,8 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FlexAttention does not support kv sharing yet.")
|
"FlexAttention does not support kv sharing yet.")
|
||||||
|
|
||||||
support_head_sizes = FlexAttentionBackend.get_supported_head_sizes()
|
FlexAttentionBackend.validate_head_size(head_size)
|
||||||
if head_size not in support_head_sizes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Head size {head_size} is not supported by FlashAttention. "
|
|
||||||
f"Supported head sizes are: {support_head_sizes}. "
|
|
||||||
"Set VLLM_USE_V1=0 to use another attention backend.")
|
|
||||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FlexAttention does not support quantized kv-cache. Yet")
|
"FlexAttention does not support quantized kv-cache. Yet")
|
||||||
@ -464,12 +457,20 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
# Doesn't work for now -> constraint violation
|
# Doesn't work for now -> constraint violation
|
||||||
# torch._dynamo.try_mark_dynamic(query, 2)
|
# torch._dynamo.try_mark_dynamic(query, 2)
|
||||||
|
|
||||||
# default M=64, N=64 may run out of shared memory on
|
# default M=64, N=64 may run out of shared memory on some GPUs
|
||||||
# some GPUs with fp32, so we use smaller M and N.
|
# TODO: Explicit configs for each GPU?
|
||||||
extra_kernel_options = {
|
# Not sure how to calculate the shared memory requirement
|
||||||
"BLOCK_M": 32,
|
extra_kernel_options = defaultdict[str, int](lambda: 64)
|
||||||
"BLOCK_N": 32
|
if query.dtype == torch.float32:
|
||||||
} if query.dtype == torch.float32 else {}
|
extra_kernel_options["BLOCK_M"] //= 2
|
||||||
|
extra_kernel_options["BLOCK_N"] //= 2
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
device_props = torch.cuda.get_device_properties()
|
||||||
|
max_shared_memory = device_props.shared_memory_per_block_optin
|
||||||
|
if max_shared_memory < 144 * 1024:
|
||||||
|
extra_kernel_options["BLOCK_M"] //= 2
|
||||||
|
extra_kernel_options["BLOCK_N"] //= 2
|
||||||
|
|
||||||
out = flex_attention_compiled(
|
out = flex_attention_compiled(
|
||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
|
|||||||
@ -254,10 +254,21 @@ class MLACommonBackend(AttentionBackend):
|
|||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
return (num_blocks, block_size, head_size)
|
return (num_blocks, block_size, head_size)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_supported_head_sizes() -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
return [576]
|
return [576]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
|
supported_head_sizes = cls.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
attn_type = cls.__name__.removesuffix("Backend")
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by {attn_type}. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}. "
|
||||||
|
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||||
|
"FlexAttention backend which supports all head sizes.")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MLACommonPrefillMetadata:
|
class MLACommonPrefillMetadata:
|
||||||
@ -320,12 +331,8 @@ class MLACommonMetadata(Generic[D]):
|
|||||||
prefill: Optional[MLACommonPrefillMetadata] = None
|
prefill: Optional[MLACommonPrefillMetadata] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
|
if self.head_dim is not None:
|
||||||
if self.head_dim is not None and self.head_dim \
|
MLACommonBackend.validate_head_size(self.head_dim)
|
||||||
not in supported_head_sizes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
|
||||||
f"received {self.head_dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar("M", bound=MLACommonMetadata)
|
M = TypeVar("M", bound=MLACommonMetadata)
|
||||||
|
|||||||
@ -314,10 +314,21 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_supported_head_sizes() -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
|
supported_head_sizes = cls.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
attn_type = cls.__name__.removesuffix("Backend")
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by {attn_type}. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}. "
|
||||||
|
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||||
|
"FlexAttention backend which supports all head sizes.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "FLASH_ATTN_VLLM_V1"
|
return "FLASH_ATTN_VLLM_V1"
|
||||||
@ -428,14 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
support_head_sizes = \
|
AiterFlashAttentionBackend.validate_head_size(head_size)
|
||||||
AiterFlashAttentionBackend.get_supported_head_sizes()
|
|
||||||
if head_size not in support_head_sizes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Head size {head_size} is not supported by "
|
|
||||||
"AiterFlashAttention. "
|
|
||||||
f"Supported head sizes are: {support_head_sizes}. "
|
|
||||||
"Set VLLM_USE_V1=0 to use another attention backend.")
|
|
||||||
|
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
|||||||
@ -190,10 +190,21 @@ class TritonAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_supported_head_sizes() -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_head_size(cls, head_size: int) -> None:
|
||||||
|
supported_head_sizes = cls.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
attn_type = cls.__name__.removesuffix("Backend")
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by {attn_type}. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}. "
|
||||||
|
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||||
|
"FlexAttention backend which supports all head sizes.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "TRITON_ATTN_VLLM_V1"
|
return "TRITON_ATTN_VLLM_V1"
|
||||||
@ -268,11 +279,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
|
TritonAttentionBackend.validate_head_size(head_size)
|
||||||
if head_size not in support_head_sizes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Head size {head_size} is not supported by TritonAttention. "
|
|
||||||
f"Supported head sizes are: {support_head_sizes}.")
|
|
||||||
|
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
|||||||
@ -12,8 +12,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import supports_multimodal
|
from vllm.model_executor.models import supports_multimodal
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
FlashAttentionMetadata)
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
|
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user