[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Authored by Matthew Bonanni on 2025-12-05 12:48:43 -05:00; committed by GitHub.
parent dff0a2b394
commit 66e674cdd5
22 changed files with 367 additions and 325 deletions
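
For reviewers skimming the diff, here is a minimal usage sketch of the new surface. The flag spellings come from the EngineArgs changes later in this diff; the JSON form of --attention-config is an assumption based on how other config arguments are passed, and the string-to-enum parsing relies on the validator added in vllm/config/attention.py.

# Minimal sketch, assuming a build of vLLM that contains this PR and no legacy
# attention env vars set in the environment.
#
#   Old: VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve <model>   (deprecated, still warns)
#   New: vllm serve <model> --attention-backend FLASHINFER
#        vllm serve <model> --attention-config '{"backend": "FLASHINFER"}'  (assumed JSON form)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig

# The field_validator added in this PR also accepts the string spelling.
cfg = AttentionConfig(backend="flashinfer")
assert cfg.backend is AttentionBackendEnum.FLASHINFER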

View File

@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
CompilationMode,
@ -335,6 +335,7 @@ def test_attention_quant_pattern(
custom_ops=custom_ops_list,
),
cache_config=CacheConfig(cache_dtype="fp8"),
attention_config=AttentionConfig(backend=backend),
)
# Create test inputs
@ -352,7 +353,6 @@ def test_attention_quant_pattern(
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
global_force_attn_backend_context_manager(backend),
):
model_unfused = model_class(
num_qo_heads=num_qo_heads,
@ -378,7 +378,6 @@ def test_attention_quant_pattern(
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
global_force_attn_backend_context_manager(backend),
):
model_fused = model_class(
num_qo_heads=num_qo_heads,

View File

@ -1151,13 +1151,29 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
with (
patch(
@ -1192,7 +1208,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
assert len(caches_data) == 4
assert len(caches_data) == expected_num_entries
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
@ -1214,7 +1230,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
expected_block_len = expected_tensor_size // 2
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, (
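
The new branch above keys everything off the KV-cache shape the backend reports for a tiny probe (num_blocks=1): a blocks-first layout registers one contiguous region per cache tensor, while a K/V-first layout registers K and V separately. An illustrative-only restatement of that decision, not the connector's real code:

# Illustrative only: with the probe built at num_blocks=1, a 5-D shape whose
# leading dimension is 1 means "blocks first", so the two cache tensors in the
# test (shared + unique) register 2 entries; otherwise K and V are split at the
# front and the same tensors register 4 entries.
def expected_reg_entries(probe_shape: tuple[int, ...], num_cache_tensors: int = 2) -> int:
    is_blocks_first = len(probe_shape) == 5 and probe_shape[0] == 1
    return num_cache_tensors if is_blocks_first else 2 * num_cache_tensors

assert expected_reg_entries((1, 2, 16, 1, 1)) == 2  # blocks-first layout (e.g. FlashInfer)
assert expected_reg_entries((2, 1, 16, 1, 1)) == 4  # K/V-first layout (e.g. FlashAttention)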

View File

@ -6,8 +6,10 @@ import pytest
import torch
from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.config import (
AttentionConfig,
CacheConfig,
ModelConfig,
ParallelConfig,
@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
def test_hybrid_attention_mamba_tensor_shapes():
"""
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype="auto",
)
parallel_config = ParallelConfig()
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
attention_config=attention_config,
)
layer_0 = "model.layers.0.self_attn.attn"
@ -820,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:
@ -851,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)
# suppress var not used error
assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
vllm_ctx = vllm_config.compilation_config.static_forward_context
runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
@ -865,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)[0]
runner.initialize_kv_cache(kv_cache_config)
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks = kv_cache_config.num_blocks
ind = np.arange(num_blocks)
np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks = kv_cache_config.num_blocks
ind = np.arange(num_blocks)
np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
# assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks
# assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks
# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))
# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]
attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
)
conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
)
ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
)
attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
)
conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
)
ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
)
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i]
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i]
# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)
# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
def test_hybrid_block_table_initialization():

View File

@ -289,6 +289,16 @@ class AttentionImpl(ABC, Generic[T]):
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False
# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
dcp_world_size: int
dcp_rank: int
@ -368,22 +378,6 @@ class AttentionImpl(ABC, Generic[T]):
"""
return False
def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
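
With the method above removed, the capability is now a plain attribute: the base class carries a False default and concrete implementations opt in from __init__ (as the FlashAttention, FlashInfer, and Triton changes later in this diff do). A simplified, hypothetical sketch of the convention, not a real backend:

# Hypothetical toy classes; the real base class is
# vllm.attention.backends.abstract.AttentionImpl.
class ToyAttentionImpl:
    # class-level default, mirroring AttentionImpl.supports_quant_query_input
    supports_quant_query_input: bool = False

class ToyTrtllmImpl(ToyAttentionImpl):
    def __init__(self, can_take_fp8_query: bool) -> None:
        # an implementation opts in from __init__ when its kernel path allows it
        self.supports_quant_query_input = can_take_fp8_query

impl = ToyTrtllmImpl(can_take_fp8_query=True)
assert impl.supports_quant_query_input            # attribute read, no call parentheses
assert not ToyAttentionImpl.supports_quant_query_input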

View File

@ -303,7 +303,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input()
and self.impl.supports_quant_query_input
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
@ -338,7 +338,7 @@ class Attention(nn.Module, AttentionLayerBase):
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
# check if query quantization is supported
if self.impl.supports_quant_query_input():
if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:

View File

@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import os
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast, get_args
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
AttentionBackendEnum,
MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
@ -24,60 +19,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
"""
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.
Returns:
* AttentionBackendEnum value if an override is specified
* None otherwise
"""
backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
if backend_name is None:
return None
if backend_name == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
return AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
"""
Force all attention operations to use a specified backend.
Passing `None` for the argument re-enables automatic
backend selection.
Arguments:
* attn_backend: backend selection (None to revert to auto)
"""
global forced_attn_backend
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
"""
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
"""
return forced_attn_backend
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@ -97,7 +38,13 @@ def get_attn_backend(
f"Valid values are: {valid_cache_dtypes}"
)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend(
backend=backend_enum,
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@ -111,6 +58,7 @@ def get_attn_backend(
@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
@ -120,39 +68,6 @@ def _cached_get_attn_backend(
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
if backend_by_env_var.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the environment variable "
"VLLM_ATTENTION_BACKEND is no longer necessary as "
"V0 backends have been deprecated. "
"Please remove this suffix from your "
"environment variable setting.",
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
try:
selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) from e
# get device-specific attn_backend
from vllm.platforms import current_platform
sig = inspect.signature(current_platform.get_attn_backend_cls)
@ -163,7 +78,7 @@ def _cached_get_attn_backend(
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
selected_backend,
backend,
head_size,
dtype,
kv_cache_dtype,
@ -176,7 +91,7 @@ def _cached_get_attn_backend(
)
else:
attention_cls = current_platform.get_attn_backend_cls(
selected_backend,
backend,
head_size,
dtype,
kv_cache_dtype,
@ -232,37 +147,3 @@ def _cached_get_mamba_attn_backend(
mamba_attn_backend = selected_backend.get_class()
return mamba_attn_backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
"""
Globally force a vLLM attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Arguments:
* attn_backend: attention backend to force
Returns:
* Generator
"""
# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
_cached_get_attn_backend.cache_clear()
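
With the module-level override and its context manager gone, callers that used global_force_attn_backend_context_manager now route the choice through the active VllmConfig, as the test changes at the top of this diff do. A sketch of the pattern, assuming a build with this PR; the backend choice is only an example:

# Sketch: force a backend for a scope by putting it in the current VllmConfig.
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

vllm_config = VllmConfig(
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER),
)
with set_current_vllm_config(vllm_config):
    # get_attn_backend(...) now reads vllm_config.attention_config.backend
    # internally instead of VLLM_ATTENTION_BACKEND or a global override.
    pass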

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
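
The override in step 2 now comes from the active config; the old environment variable keeps working through the deprecation bridge in AttentionConfig. A small sketch, with the CLI spelling assumed from the --attention-config argument added later in this diff:

# Sketch: pin FlashAttention v3 via config instead of VLLM_FLASH_ATTN_VERSION.
# Assumed CLI spelling: --attention-config '{"flash_attn_version": 3}'
from vllm.config import AttentionConfig

cfg = AttentionConfig(flash_attn_version=3)
assert cfg.flash_attn_version == 3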

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.attention import AttentionConfig
from vllm.config.cache import CacheConfig
from vllm.config.compilation import (
CompilationConfig,
@ -46,6 +47,8 @@ from vllm.config.vllm import (
# __all__ should only contain classes and functions.
# Types and globals should be imported from their respective modules.
__all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache
"CacheConfig",
# From vllm.config.compilation

vllm/config/attention.py (new file, 114 lines added)
View File

@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.logger import init_logger
logger = init_logger(__name__)
@config
@dataclass
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None
"""Force vLLM to use a specific flash-attention version (2 or 3).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for attention instead of
the unified Triton kernel."""
flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Maximum number of flash-attention splits for CUDA graph decode."""
use_cudnn_prefill: bool = False
"""Whether to use cuDNN prefill."""
use_trtllm_ragged_deepseek_prefill: bool = False
"""Whether to use the TRT-LLM ragged DeepSeek prefill kernel."""
use_trtllm_attention: bool | None = None
"""If True, use the TRT-LLM attention backend in FlashInfer; if False, do not
use it. If None, auto-detect which attention backend FlashInfer should use."""
disable_flashinfer_prefill: bool = False
"""Whether to disable FlashInfer prefill."""
disable_flashinfer_q_quantization: bool = False
"""If set, do not quantize Q to fp8 when using an fp8 KV cache."""
def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
from vllm.config.utils import get_hash_factors, hash_factors
ignored_factors: list[str] = []
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)
@field_validator("backend", mode="before")
@classmethod
def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string."""
if isinstance(value, str):
return AttentionBackendEnum[value.upper()]
return value
def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
"""Set field from env var if set, with deprecation warning."""
from vllm import envs
if envs.is_set(env_var_name):
value = getattr(envs, env_var_name)
if field_name == "backend":
value = self.validate_backend_before(value)
setattr(self, field_name, value)
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in "
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
"--attention-config.%s command line argument or "
"AttentionConfig(%s=...) config field instead.",
env_var_name,
field_name,
field_name,
)
def __post_init__(self) -> None:
self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND")
self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION")
self._set_from_env_if_set(
"use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"
)
self._set_from_env_if_set(
"flash_attn_max_num_splits_for_cuda_graph",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
)
self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL")
self._set_from_env_if_set(
"use_trtllm_ragged_deepseek_prefill",
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
)
self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION")
self._set_from_env_if_set(
"disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"
)
self._set_from_env_if_set(
"disable_flashinfer_q_quantization",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
)
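
A sketch of the deprecation bridge defined in __post_init__ above: if a legacy environment variable is set when the config is constructed, its value is copied into the matching field and a one-time warning points at the new CLI/config spelling. Behavior assumed from this diff; the env var used here is one of those listed above.

import os

from vllm.config import AttentionConfig

# Legacy spelling is still honored (with a deprecation warning) as long as it is
# set before the config object is constructed, since __post_init__ reads it once.
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"
cfg = AttentionConfig()
assert cfg.use_cudnn_prefill is True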

View File

@ -4,7 +4,6 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
@ -467,18 +466,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
if (
(backend := envs.VLLM_ATTENTION_BACKEND)
and backend == "FLASHINFER"
and find_spec("flashinfer") is None
):
raise ValueError(
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
"module was not found. See "
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it."
)
from vllm.platforms import current_platform
if self.override_attention_dtype is not None and not current_platform.is_rocm():

View File

@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .attention import AttentionConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
@ -192,6 +193,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: LoRAConfig | None = None
"""LoRA configuration."""
speculative_config: SpeculativeConfig | None = None
@ -279,6 +282,10 @@ class VllmConfig:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
else:

View File

@ -34,6 +34,7 @@ from typing_extensions import TypeIs
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
ConfigType,
@ -527,6 +528,7 @@ class EngineArgs:
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
@ -542,6 +544,7 @@ class EngineArgs:
)
model_impl: str = ModelConfig.model_impl
override_attention_dtype: str = ModelConfig.override_attention_dtype
attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
@ -580,6 +583,8 @@ class EngineArgs:
# CompilationConfig object
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(**self.compilation_config)
if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.eplb_config, dict):
self.eplb_config = EPLBConfig(**self.eplb_config)
# Setup plugins
@ -717,6 +722,16 @@ class EngineArgs:
"--pt-load-map-location", **load_kwargs["pt_load_map_location"]
)
# Attention arguments
attention_kwargs = get_kwargs(AttentionConfig)
attention_group = parser.add_argument_group(
title="AttentionConfig",
description=AttentionConfig.__doc__,
)
attention_group.add_argument(
"--attention-backend", **attention_kwargs["backend"]
)
# Structured outputs arguments
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
structured_outputs_group = parser.add_argument_group(
@ -1140,6 +1155,9 @@ class EngineArgs:
vllm_group.add_argument(
"--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
)
vllm_group.add_argument(
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
)
vllm_group.add_argument(
"--additional-config", **vllm_kwargs["additional_config"]
)
@ -1693,6 +1711,16 @@ class EngineArgs:
if model_config.quantization == "bitsandbytes":
self.quantization = self.load_format = "bitsandbytes"
# Attention config overrides
attention_config = copy.deepcopy(self.attention_config)
if self.attention_backend is not None:
if attention_config.backend is not None:
raise ValueError(
"attention_backend and attention_config.backend "
"are mutually exclusive"
)
attention_config.backend = self.attention_backend
load_config = self.create_load_config()
# Pass reasoning_parser into StructuredOutputsConfig
@ -1750,9 +1778,10 @@ class EngineArgs:
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
load_config=load_config,
attention_config=attention_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
structured_outputs_config=self.structured_outputs_config,
observability_config=observability_config,
compilation_config=compilation_config,
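
The override handling above folds the standalone --attention-backend flag into attention_config and rejects setting both. A sketch of the programmatic equivalent; the model name is only a placeholder:

# Sketch, assuming a build of vLLM with this PR.
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.engine.arg_utils import EngineArgs

# Two equivalent ways to pick a backend:
args_flag = EngineArgs(
    model="facebook/opt-125m",
    attention_backend=AttentionBackendEnum.FLASHINFER,
)
args_config = EngineArgs(
    model="facebook/opt-125m",
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER),
)
# Supplying both raises ValueError("attention_backend and attention_config.backend
# are mutually exclusive") when create_engine_config() builds the VllmConfig.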

View File

@ -4,7 +4,7 @@ from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
@ -331,6 +331,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
attention_config = vllm_config.attention_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
@ -347,7 +348,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
use_cutlass_mla = (
attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
)
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
@ -361,8 +364,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
current_platform.is_device_capability(100)
and model_config.get_head_size() == 256
and (
envs.VLLM_ATTENTION_BACKEND is None
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
attention_config.backend is None
or attention_config.backend == AttentionBackendEnum.FLASHINFER
)
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that

View File

@ -11,7 +11,7 @@ import torch
from transformers import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@ -91,10 +91,7 @@ def get_vit_attn_backend(
if attn_backend_override is not None:
return attn_backend_override
# Lazy import to avoid circular dependency
from vllm.attention.selector import get_env_variable_attn_backend
selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
selected_backend = get_current_vllm_config().attention_config.backend
if selected_backend is not None:
return selected_backend

View File

@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
@ -149,6 +148,8 @@ class CudaPlatformBase(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config
@ -171,7 +172,7 @@ class CudaPlatformBase(Platform):
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
@ -179,23 +180,25 @@ class CudaPlatformBase(Platform):
use_cutlass_mla = False
use_flashinfer_mla = False
if envs.VLLM_ATTENTION_BACKEND is None:
if vllm_config.attention_config.backend is None:
# Default case
if cls.is_device_capability(100):
# Blackwell => Force CutlassMLA.
use_cutlass_mla = True
# TODO: This does not work, because the
# global_force_attn_backend_context_manager is not set.
# See vllm/attention/selector.py:_cached_get_attn_backend
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.CUTLASS_MLA
)
else:
# Not Blackwell
use_flashmla = True
else:
# Forced case
use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
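
For readers tracing the MLA defaulting above: with no backend configured, Blackwell (device capability 10.0) now gets CUTLASS_MLA written back into attention_config so later selection sees it, other GPUs default to FlashMLA, and an explicitly configured backend is honored as-is. An illustrative-only restatement of that decision, not the real platform hook:

# Illustrative only; the real logic lives in CudaPlatformBase.check_and_update_config.
from vllm.attention.backends.registry import AttentionBackendEnum

def pick_mla_backend(
    configured: AttentionBackendEnum | None, is_blackwell: bool
) -> AttentionBackendEnum:
    if configured is not None:
        return configured  # forced case: honor the configured backend
    return (
        AttentionBackendEnum.CUTLASS_MLA if is_blackwell
        else AttentionBackendEnum.FLASHMLA
    )

assert pick_mla_backend(None, is_blackwell=True) is AttentionBackendEnum.CUTLASS_MLA
assert pick_mla_backend(None, is_blackwell=False) is AttentionBackendEnum.FLASHMLA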

View File

@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value
def force_use_trtllm_attention() -> bool | None:
"""
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
Return `None` if --attention-config.use_trtllm_attention is not set,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return vllm_config.attention_config.use_trtllm_attention
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
@ -307,7 +302,7 @@ def use_trtllm_attention(
"""Return `True` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()
# Environment variable is set to 0 - respect it
# CLI argument is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False
@ -324,7 +319,7 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"but --attention-config.use_trtllm_attention is set to 1"
)
return False
@ -333,7 +328,8 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
)
return False
@ -354,7 +350,7 @@ def use_trtllm_attention(
return True
if force_use_trtllm is None:
# Environment variable not set - use auto-detection
# CLI argument not set - use auto-detection
if is_prefill:
# Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto"
@ -367,8 +363,10 @@ def use_trtllm_attention(
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm
# Environment variable is set to 1 - respect it
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
# CLI argument is set to 1 - respect it
logger.info_once(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return True
@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
return output
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
@ -526,7 +518,6 @@ __all__ = [
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]
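
force_use_trtllm_attention() now surfaces the tri-state field straight from the active config instead of a cached env var. A short sketch of the three states; the CLI spelling is an assumption:

# Sketch of the tri-state control: None = auto-detect, True = force TRT-LLM
# attention (warns if unsupported), False = never use it.
# Assumed CLI spelling: --attention-config '{"use_trtllm_attention": false}'
from vllm.config import AttentionConfig

assert AttentionConfig().use_trtllm_attention is None                        # auto
assert AttentionConfig(use_trtllm_attention=True).use_trtllm_attention       # forced on
assert not AttentionConfig(use_trtllm_attention=False).use_trtllm_attention  # forced off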

View File

@ -8,7 +8,6 @@ from typing import ClassVar
import numpy as np
import torch
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config
@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
self.max_num_splits = (
self.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer"
)
def supports_quant_query_input(self) -> bool:
return True
self.supports_quant_query_input = True
def forward(
self,

View File

@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import (
can_use_trtllm_attention,
flashinfer_disable_q_quantization,
use_trtllm_attention,
)
from vllm.utils.math_utils import cdiv
@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
supports_trtllm_attention,
)
# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
# Respect explicit disable flag (e.g.,
# --attention-config.use_trtllm_attention=0)
if force_use_trtllm_attention() is False:
return False
@ -500,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_dtype = self.kv_cache_spec.dtype
# Use model dtype as q dtype when TRTLLM attn is not supported, or
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
# use fp8 q if kv cache is fp8, and will fall back to model dtype
# --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if can_use_trtllm and not flashinfer_disable_q_quantization():
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
):
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype
@ -1035,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
self.sinks = sinks
self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
vllm_config = get_current_vllm_config()
self.supports_quant_query_input = (
self.support_trtllm_attn
and not vllm_config.attention_config.disable_flashinfer_q_quantization
)
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None
@ -1046,12 +1054,6 @@ class FlashInferImpl(AttentionImpl):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
)
def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False
return self.support_trtllm_attn
# FlashInfer requires attention sinks to be float32
def process_weights_after_loading(self, act_dtype: torch.dtype):
if self.sinks is not None and self.sinks.dtype != torch.float32:

View File

@ -438,19 +438,25 @@ A = TypeVar("A")
def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return (
not envs.VLLM_DISABLE_FLASHINFER_PREFILL
not vllm_config.attention_config.disable_flashinfer_prefill
and flashinfer_available
and not envs.VLLM_USE_CUDNN_PREFILL
and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and not vllm_config.attention_config.use_cudnn_prefill
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability(100)
)
def use_cudnn_prefill() -> bool:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return (
flashinfer_available
and envs.VLLM_USE_CUDNN_PREFILL
and vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability(100)
and has_nvidia_artifactory()
)
@ -458,9 +464,12 @@ def use_cudnn_prefill() -> bool:
def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return (
flashinfer_available
and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability(100)
)

View File

@ -6,7 +6,6 @@ from typing import ClassVar
import torch
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
@ -131,7 +130,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
self.max_num_splits = (
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
)
if vllm_is_batch_invariant():
self.max_num_splits = 1

View File

@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"Set --attention-config.backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)

View File

@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def supports_quant_query_input(self) -> bool:
return current_platform.is_cuda()
def __init__(
self,
num_heads: int,
@ -262,6 +259,8 @@ class TritonAttentionImpl(AttentionImpl):
f"num_heads: {num_heads}."
)
self.supports_quant_query_input = current_platform.is_cuda()
def forward(
self,
layer: torch.nn.Module,