[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Matthew Bonanni 2025-12-05 12:48:43 -05:00 committed by GitHub
parent dff0a2b394
commit 66e674cdd5
22 changed files with 367 additions and 325 deletions
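
A minimal sketch of the UX change for reviewers (the model name is a placeholder, and the JSON form of --attention-config is assumed to mirror how --compilation-config accepts a JSON dict):

# Before this change, the backend override was an environment variable:
#   VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve <model>
# After this change, the override is a CLI argument / config field:
#   vllm serve <model> --attention-backend FLASHINFER
#   vllm serve <model> --attention-config '{"backend": "FLASHINFER"}'
# or, programmatically:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig

attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)

The old environment variables are still honored via AttentionConfig.__post_init__ (see the new vllm/config/attention.py below), but now emit a deprecation warning.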

View File

@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import Attention
-from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.matcher_utils import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     CompilationConfig,
     CompilationMode,
@@ -335,6 +335,7 @@ def test_attention_quant_pattern(
             custom_ops=custom_ops_list,
         ),
         cache_config=CacheConfig(cache_dtype="fp8"),
+        attention_config=AttentionConfig(backend=backend),
     )
     # Create test inputs
@@ -352,7 +353,6 @@ def test_attention_quant_pattern(
     with (
         set_current_vllm_config(vllm_config_unfused),
         set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
-        global_force_attn_backend_context_manager(backend),
     ):
         model_unfused = model_class(
             num_qo_heads=num_qo_heads,
@@ -378,7 +378,6 @@ def test_attention_quant_pattern(
     with (
         set_current_vllm_config(vllm_config),
         set_forward_context(attn_metadata=None, vllm_config=vllm_config),
-        global_force_attn_backend_context_manager(backend),
     ):
         model_fused = model_class(
            num_qo_heads=num_qo_heads,

View File

@@ -1151,13 +1151,29 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     }
     # Store tensor info for validation
-    expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
-    expected_base_addrs = [
-        shared_tensor[0].data_ptr(),
-        shared_tensor[1].data_ptr(),
-        unique_tensor[0].data_ptr(),
-        unique_tensor[1].data_ptr(),
-    ]
+    test_shape = backend_cls.get_kv_cache_shape(
+        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+    )
+    is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
+    if is_blocks_first:
+        expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
+        expected_base_addrs = [
+            shared_tensor.data_ptr(),
+            unique_tensor.data_ptr(),
+        ]
+        expected_num_entries = 2
+    else:
+        expected_tensor_size = (
+            shared_tensor[0].element_size() * shared_tensor[0].numel()
+        )
+        expected_base_addrs = [
+            shared_tensor[0].data_ptr(),
+            shared_tensor[1].data_ptr(),
+            unique_tensor[0].data_ptr(),
+            unique_tensor[1].data_ptr(),
+        ]
+        expected_num_entries = 4
     with (
         patch(
@@ -1192,7 +1208,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     # Verify get_reg_descs was called with caches_data
     assert mock_wrapper_instance.get_reg_descs.called
     caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
-    assert len(caches_data) == 4
+    assert len(caches_data) == expected_num_entries
     for i, cache_entry in enumerate(caches_data):
         base_addr, size, _tp_rank, _ = cache_entry
@@ -1214,7 +1230,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
         f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
     )
-    expected_block_len = expected_tensor_size // 2
+    num_blocks = 2
+    if is_blocks_first:
+        expected_block_len = expected_tensor_size // num_blocks // 2
+    else:
+        expected_block_len = expected_tensor_size // num_blocks
     for i, block_entry in enumerate(blocks_data):
         block_start_addr, block_len, tp_rank = block_entry
         assert block_len == expected_block_len, (

View File

@@ -6,8 +6,10 @@ import pytest
 import torch
 from vllm.attention.backends.abstract import MultipleOf
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import Attention
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     ModelConfig,
     ParallelConfig,
@@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
     current_platform.is_rocm(),
     reason="Attention backend FLASHINFER is not supported on ROCm.",
 )
-def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
+def test_hybrid_attention_mamba_tensor_shapes():
     """
     The GPU model runner creates different views into the
     KVCacheTensors for the attention and mamba layers
@@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         cache_dtype="auto",
     )
     parallel_config = ParallelConfig()
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
     vllm_config = VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
+        attention_config=attention_config,
     )
     layer_0 = "model.layers.0.self_attn.attn"
@@ -820,8 +824,7 @@
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"
-    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+    with set_current_vllm_config(vllm_config):
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:
@@ -851,10 +854,7 @@
             )
         # suppress var not used error
         assert fwd_context is not None
         vllm_ctx = vllm_config.compilation_config.static_forward_context
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-            runner = GPUModelRunner(vllm_config, DEVICE)
+        runner = GPUModelRunner(vllm_config, DEVICE)
         kv_cache_spec = runner.get_kv_cache_spec()
@@ -865,94 +865,94 @@
         )[0]
         runner.initialize_kv_cache(kv_cache_config)
         # random partition of blocks
         # blocks0 will be assigned to attention layers
         # blocks1 will be assigned to mamba layers
         num_blocks = kv_cache_config.num_blocks
         ind = np.arange(num_blocks)
         np.random.shuffle(ind)
         blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
         attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
         conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
         ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
         # assert we are using FlashInfer
         assert attn_shape[0] % num_blocks == 0
         block_split_ratio = attn_shape[0] // num_blocks
         # use small blocks for testing to avoid memory issues
         test_block_size = min(2, len(blocks0), len(blocks1))
         # use non-overlapping blocks to avoid data contamination
         # Split kernel blocks: first half for attention, second half for mamba
         mid_point = num_blocks // 2
         # attention uses kernel blocks from first half (mapped to logical blocks)
         kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
         # mamba uses kernel blocks from second half
         kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
         # create small constant tensors for testing with corrected shapes
         # attention: [block_size, ...] starting from dimension 2
         attn_constant_shape = attn_shape[2:]
         conv_constant_shape = conv_shape[1:]
         ssm_constant_shape = ssm_shape[1:]
         attn_blocks_constant = torch.full(
             (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
         )
         conv_blocks_constant = torch.full(
             (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
         )
         ssm_blocks_constant = torch.full(
             (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
         )
         # Fill attention blocks with constants using kv block indices
         kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
         for layer in [layer_0, layer_1]:
             # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
             for i, kernel_block in enumerate(kernel_blocks_for_attention):
                 vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
         # fill mamba blocks with constants using kernel block indices
         for layer in [layer_2, layer_3, layer_4, layer_5]:
             # mamba: kv_cache[0][component][kernel_block_idx, ...]
             for i, kv_block in enumerate(kv_blocks_for_mamba):
                 vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
                 vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
         # verify attention and mamba contents are correct
         for layer in [layer_0, layer_1]:
             for i, kernel_block in enumerate(kernel_blocks_for_attention):
                 actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
                 expected = attn_blocks_constant[i]
                 # Check K and V separately
                 assert torch.equal(actual_kv[0], expected)
                 assert torch.equal(actual_kv[1], expected)
         for layer in [layer_2, layer_3, layer_4, layer_5]:
             for i, kv_block in enumerate(kv_blocks_for_mamba):
                 actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
                 actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
                 expected_conv = conv_blocks_constant[i]
                 expected_ssm = ssm_blocks_constant[i]
                 assert torch.equal(actual_conv, expected_conv)
                 assert torch.equal(actual_ssm, expected_ssm)
         for layer in [layer_2, layer_3, layer_4, layer_5]:
             for i, kv_block in enumerate(kv_blocks_for_mamba):
                 actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
                 actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
                 expected_conv = conv_blocks_constant[i]
                 expected_ssm = ssm_blocks_constant[i]
                 assert torch.equal(actual_conv, expected_conv)
                 assert torch.equal(actual_ssm, expected_ssm)
 def test_hybrid_block_table_initialization():

View File

@@ -289,6 +289,16 @@ class AttentionImpl(ABC, Generic[T]):
     # even if they can return lse (for efficiency reasons)
     need_to_return_lse_for_decode: bool = False
+    # Whether this attention implementation supports pre-quantized query input.
+    # When True, the attention layer will quantize queries before passing them
+    # to this backend, allowing torch.compile to fuse the quantization with
+    # previous operations. This is typically supported when using FP8 KV cache
+    # with compatible attention kernels (e.g., TRT-LLM).
+    # Subclasses should set this in __init__.
+    # TODO add support to more backends:
+    # https://github.com/vllm-project/vllm/issues/25584
+    supports_quant_query_input: bool = False
     dcp_world_size: int
     dcp_rank: int
@@ -368,22 +378,6 @@
         """
         return False
-    def supports_quant_query_input(self) -> bool:
-        """
-        Check if this attention implementation supports pre-quantized query input.
-        When True, the attention layer will quantize queries before passing them
-        to this backend, allowing torch.compile to fuse the quantization with
-        previous operations. This is typically supported when using FP8 KV cache
-        with compatible attention kernels (e.g., TRT-LLM).
-        TODO add support to more backends:
-        https://github.com/vllm-project/vllm/issues/25584
-        Returns:
-            bool: True if the implementation can accept pre-quantized queries.
-        """
-        return False
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         pass
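
For context, a hypothetical backend (not part of this diff; required abstract methods omitted) would now opt in by setting the attribute in __init__, the same pattern the flash-attn, FlashInfer, and Triton backends adopt later in this commit:

from vllm.attention.backends.abstract import AttentionImpl

class MyAttentionImpl(AttentionImpl):  # illustrative sketch only
    def __init__(self, *args, **kwargs):
        ...
        # Attribute replaces the old supports_quant_query_input() method.
        self.supports_quant_query_input = True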

View File

@@ -303,7 +303,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.query_quant = None
         if (
             self.kv_cache_dtype.startswith("fp8")
-            and self.impl.supports_quant_query_input()
+            and self.impl.supports_quant_query_input
         ):
             self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
@@ -338,7 +338,7 @@
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
             # check if query quantization is supported
-            if self.impl.supports_quant_query_input():
+            if self.impl.supports_quant_query_input:
                 query, _ = self.query_quant(query, self._q_scale)
         if self.use_output:

View File

@@ -2,19 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import inspect
-import os
-from collections.abc import Generator
-from contextlib import contextmanager
 from functools import cache
 from typing import cast, get_args
 import torch
-import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.registry import (
     MAMBA_TYPE_TO_BACKEND_MAP,
-    AttentionBackendEnum,
     MambaAttentionBackendEnum,
 )
 from vllm.config.cache import CacheDType
@@ -24,60 +19,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
 logger = init_logger(__name__)
-def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
-    """
-    Get the backend override specified by the vLLM attention
-    backend environment variable, if one is specified.
-    Returns:
-    * AttentionBackendEnum value if an override is specified
-    * None otherwise
-    """
-    backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
-    if backend_name is None:
-        return None
-    if backend_name == "XFORMERS":
-        raise ValueError(
-            "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
-            "details). Please select a supported attention backend."
-        )
-    return AttentionBackendEnum[backend_name]
-# Global state allows a particular choice of backend
-# to be forced, overriding the logic which auto-selects
-# a backend based on system & workload configuration
-# (default behavior if this variable is None)
-#
-# THIS SELECTION TAKES PRECEDENCE OVER THE
-# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
-forced_attn_backend: AttentionBackendEnum | None = None
-def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
-    """
-    Force all attention operations to use a specified backend.
-    Passing `None` for the argument re-enables automatic
-    backend selection.,
-    Arguments:
-    * attn_backend: backend selection (None to revert to auto)
-    """
-    global forced_attn_backend
-    forced_attn_backend = attn_backend
-def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
-    """
-    Get the currently-forced choice of attention backend,
-    or None if auto-selection is currently enabled.
-    """
-    return forced_attn_backend
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -97,7 +38,13 @@ def get_attn_backend(
             f"Valid values are: {valid_cache_dtypes}"
         )
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    backend_enum = vllm_config.attention_config.backend
     return _cached_get_attn_backend(
+        backend=backend_enum,
         head_size=head_size,
         dtype=dtype,
         kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@@ -111,6 +58,7 @@
 @cache
 def _cached_get_attn_backend(
+    backend,
     head_size: int,
     dtype: torch.dtype,
     kv_cache_dtype: CacheDType | None,
@@ -120,39 +68,6 @@
     use_sparse: bool = False,
     attn_type: str | None = None,
 ) -> type[AttentionBackend]:
-    # Check whether a particular choice of backend was
-    # previously forced.
-    #
-    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
-    # ENVIRONMENT VARIABLE.
-    selected_backend = None
-    backend_by_global_setting: AttentionBackendEnum | None = (
-        get_global_forced_attn_backend()
-    )
-    if backend_by_global_setting is not None:
-        selected_backend = backend_by_global_setting
-    else:
-        # Check the environment variable and override if specified
-        backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            if backend_by_env_var.endswith("_VLLM_V1"):
-                logger.warning(
-                    "The suffix '_VLLM_V1' in the environment variable "
-                    "VLLM_ATTENTION_BACKEND is no longer necessary as "
-                    "V0 backends have been deprecated. "
-                    "Please remove this suffix from your "
-                    "environment variable setting.",
-                )
-                backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
-            try:
-                selected_backend = AttentionBackendEnum[backend_by_env_var]
-            except KeyError as e:
-                raise ValueError(
-                    f"Invalid attention backend: '{backend_by_env_var}'. Valid "
-                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
-                ) from e
-    # get device-specific attn_backend
     from vllm.platforms import current_platform
     sig = inspect.signature(current_platform.get_attn_backend_cls)
@@ -163,7 +78,7 @@
             "remove it from your plugin code."
         )
         attention_cls = current_platform.get_attn_backend_cls(
-            selected_backend,
+            backend,
             head_size,
             dtype,
             kv_cache_dtype,
@@ -176,7 +91,7 @@
         )
     else:
         attention_cls = current_platform.get_attn_backend_cls(
-            selected_backend,
+            backend,
             head_size,
             dtype,
             kv_cache_dtype,
@@ -232,37 +147,3 @@ def _cached_get_mamba_attn_backend(
     mamba_attn_backend = selected_backend.get_class()
     return mamba_attn_backend
-@contextmanager
-def global_force_attn_backend_context_manager(
-    attn_backend: AttentionBackendEnum,
-) -> Generator[None, None, None]:
-    """
-    Globally force a vLLM attention backend override within a
-    context manager, reverting the global attention backend
-    override to its prior state upon exiting the context
-    manager.
-    Arguments:
-    * attn_backend: attention backend to force
-    Returns:
-    * Generator
-    """
-    # Save the current state of the global backend override (if any)
-    original_value = get_global_forced_attn_backend()
-    # Globally force the new backend override
-    global_force_attn_backend(attn_backend)
-    # Yield control back to the enclosed code block
-    try:
-        yield
-    finally:
-        # Revert the original global backend override, if any
-        global_force_attn_backend(original_value)
-        _cached_get_attn_backend.cache_clear()

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm import envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
             3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
         )
-        # 2. override if passed by environment
-        if envs.VLLM_FLASH_ATTN_VERSION is not None:
-            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
-            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+        # 2. override if passed by environment or config
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        if vllm_config.attention_config.flash_attn_version is not None:
+            fa_version = vllm_config.attention_config.flash_attn_version
         # 3. fallback for unsupported combinations
         if device_capability.major == 10 and fa_version == 3:

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.config.attention import AttentionConfig
 from vllm.config.cache import CacheConfig
 from vllm.config.compilation import (
     CompilationConfig,
@@ -46,6 +47,8 @@ from vllm.config.vllm import (
 # __all__ should only contain classes and functions.
 # Types and globals should be imported from their respective modules.
 __all__ = [
+    # From vllm.config.attention
+    "AttentionConfig",
     # From vllm.config.cache
     "CacheConfig",
     # From vllm.config.compilation

vllm/config/attention.py (new file, 114 lines added)
View File

@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any, Literal
+
+from pydantic import field_validator
+from pydantic.dataclasses import dataclass
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.config.utils import config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@config
+@dataclass
+class AttentionConfig:
+    """Configuration for attention mechanisms in vLLM."""
+
+    backend: AttentionBackendEnum | None = None
+    """Attention backend to use. If None, will be selected automatically."""
+
+    flash_attn_version: Literal[2, 3] | None = None
+    """Force vllm to use a specific flash-attention version (2 or 3).
+    Only valid when using the flash-attention backend."""
+
+    use_prefill_decode_attention: bool = False
+    """Use separate prefill and decode kernels for attention instead of
+    the unified triton kernel."""
+
+    flash_attn_max_num_splits_for_cuda_graph: int = 32
+    """Flash Attention max number splits for cuda graph decode."""
+
+    use_cudnn_prefill: bool = False
+    """Whether to use cudnn prefill."""
+
+    use_trtllm_ragged_deepseek_prefill: bool = False
+    """Whether to use TRTLLM ragged deepseek prefill."""
+
+    use_trtllm_attention: bool | None = None
+    """If set to True/False, use or don't use the TRTLLM attention backend
+    in flashinfer. If None, auto-detect the attention backend in flashinfer."""
+
+    disable_flashinfer_prefill: bool = False
+    """Whether to disable flashinfer prefill."""
+
+    disable_flashinfer_q_quantization: bool = False
+    """If set, when using fp8 kv, do not quantize Q to fp8."""
+
+    def compute_hash(self) -> str:
+        """
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        from vllm.config.utils import get_hash_factors, hash_factors
+
+        ignored_factors: list[str] = []
+        factors = get_hash_factors(self, ignored_factors)
+        return hash_factors(factors)
+
+    @field_validator("backend", mode="before")
+    @classmethod
+    def validate_backend_before(cls, value: Any) -> Any:
+        """Enable parsing of the `backend` enum type from string."""
+        if isinstance(value, str):
+            return AttentionBackendEnum[value.upper()]
+        return value
+
+    def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
+        """Set field from env var if set, with deprecation warning."""
+        from vllm import envs
+
+        if envs.is_set(env_var_name):
+            value = getattr(envs, env_var_name)
+            if field_name == "backend":
+                value = self.validate_backend_before(value)
+            setattr(self, field_name, value)
+            logger.warning_once(
+                "Using %s environment variable is deprecated and will be removed in "
+                "v0.14.0 or v1.0.0, whichever is soonest. Please use "
+                "--attention-config.%s command line argument or "
+                "AttentionConfig(%s=...) config field instead.",
+                env_var_name,
+                field_name,
+                field_name,
+            )
+
+    def __post_init__(self) -> None:
+        self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND")
+        self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION")
+        self._set_from_env_if_set(
+            "use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"
+        )
+        self._set_from_env_if_set(
+            "flash_attn_max_num_splits_for_cuda_graph",
+            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+        )
+        self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL")
+        self._set_from_env_if_set(
+            "use_trtllm_ragged_deepseek_prefill",
+            "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
+        )
+        self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION")
+        self._set_from_env_if_set(
+            "disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"
+        )
+        self._set_from_env_if_set(
+            "disable_flashinfer_q_quantization",
+            "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
+        )
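
A small usage sketch for the new config class (values are illustrative, not part of the diff):

from vllm.config import AttentionConfig

# backend accepts an AttentionBackendEnum member or a case-insensitive string;
# the field validator maps strings through AttentionBackendEnum[value.upper()].
cfg = AttentionConfig(backend="flash_attn", flash_attn_version=3)

# Legacy environment variables such as VLLM_ATTENTION_BACKEND are still picked up
# in __post_init__, but now emit a deprecation warning pointing at the CLI flags.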

View File

@@ -4,7 +4,6 @@
 import warnings
 from collections.abc import Callable
 from dataclasses import InitVar, field
-from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Literal, cast, get_args
 import torch
@@ -467,18 +466,6 @@ class ModelConfig:
         self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
-        if (
-            (backend := envs.VLLM_ATTENTION_BACKEND)
-            and backend == "FLASHINFER"
-            and find_spec("flashinfer") is None
-        ):
-            raise ValueError(
-                "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
-                "module was not found. See "
-                "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile "  # noqa: E501
-                "for instructions on how to install it."
-            )
         from vllm.platforms import current_platform
         if self.override_attention_dtype is not None and not current_platform.is_rocm():

View File

@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
 from vllm.utils.hashing import safe_hash
+from .attention import AttentionConfig
 from .cache import CacheConfig
 from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
 from .device import DeviceConfig
@@ -192,6 +193,8 @@ class VllmConfig:
     """Device configuration."""
     load_config: LoadConfig = Field(default_factory=LoadConfig)
     """Load configuration."""
+    attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
+    """Attention configuration."""
     lora_config: LoRAConfig | None = None
     """LoRA configuration."""
     speculative_config: SpeculativeConfig | None = None
@@ -279,6 +282,10 @@
             vllm_factors.append(self.load_config.compute_hash())
         else:
             vllm_factors.append("None")
+        if self.attention_config:
+            vllm_factors.append(self.attention_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
         else:

View File

@@ -34,6 +34,7 @@ from typing_extensions import TypeIs
 import vllm.envs as envs
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     CompilationConfig,
     ConfigType,
@@ -527,6 +528,7 @@ class EngineArgs:
     pooler_config: PoolerConfig | None = ModelConfig.pooler_config
     compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
+    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -542,6 +544,7 @@
     )
     model_impl: str = ModelConfig.model_impl
     override_attention_dtype: str = ModelConfig.override_attention_dtype
+    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
@@ -580,6 +583,8 @@
         # CompilationConfig object
         if isinstance(self.compilation_config, dict):
             self.compilation_config = CompilationConfig(**self.compilation_config)
+        if isinstance(self.attention_config, dict):
+            self.attention_config = AttentionConfig(**self.attention_config)
         if isinstance(self.eplb_config, dict):
             self.eplb_config = EPLBConfig(**self.eplb_config)
         # Setup plugins
@@ -717,6 +722,16 @@
             "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
         )
+        # Attention arguments
+        attention_kwargs = get_kwargs(AttentionConfig)
+        attention_group = parser.add_argument_group(
+            title="AttentionConfig",
+            description=AttentionConfig.__doc__,
+        )
+        attention_group.add_argument(
+            "--attention-backend", **attention_kwargs["backend"]
+        )
         # Structured outputs arguments
         structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
         structured_outputs_group = parser.add_argument_group(
@@ -1140,6 +1155,9 @@
         vllm_group.add_argument(
             "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
         )
+        vllm_group.add_argument(
+            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
+        )
         vllm_group.add_argument(
             "--additional-config", **vllm_kwargs["additional_config"]
         )
@@ -1693,6 +1711,16 @@
         if model_config.quantization == "bitsandbytes":
             self.quantization = self.load_format = "bitsandbytes"
+        # Attention config overrides
+        attention_config = copy.deepcopy(self.attention_config)
+        if self.attention_backend is not None:
+            if attention_config.backend is not None:
+                raise ValueError(
+                    "attention_backend and attention_config.backend "
+                    "are mutually exclusive"
+                )
+            attention_config.backend = self.attention_backend
         load_config = self.create_load_config()
         # Pass reasoning_parser into StructuredOutputsConfig
@@ -1750,9 +1778,10 @@
             parallel_config=parallel_config,
             scheduler_config=scheduler_config,
             device_config=device_config,
+            load_config=load_config,
+            attention_config=attention_config,
             lora_config=lora_config,
             speculative_config=speculative_config,
-            load_config=load_config,
             structured_outputs_config=self.structured_outputs_config,
             observability_config=observability_config,
             compilation_config=compilation_config,
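
A sketch of the resulting EngineArgs surface (model name is a placeholder; the JSON CLI syntax in the comments is assumed to follow the existing --compilation-config convention). Passing both the shorthand and an explicit attention_config.backend triggers the mutual-exclusion error above.

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.engine.arg_utils import EngineArgs

# Shorthand flag, equivalent to `--attention-backend FLASHINFER` on the CLI:
args = EngineArgs(model="facebook/opt-125m",
                  attention_backend=AttentionBackendEnum.FLASHINFER)
# Full config object, equivalent to `--attention-config '{"backend": "FLASHINFER"}'`:
args = EngineArgs(model="facebook/opt-125m",
                  attention_config=AttentionConfig(backend="FLASHINFER"))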

View File

@@ -4,7 +4,7 @@ from copy import deepcopy
 from math import lcm
 from typing import TYPE_CHECKING
-import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
@@ -331,6 +331,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         # Enable FULL_AND_PIECEWISE by default
         MambaModelConfig.verify_and_update_config(vllm_config)
+        attention_config = vllm_config.attention_config
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -347,7 +348,9 @@
         # * CUTLASS_MLA backend: kernel_block_size 128 alignment
         # * Other MLA backends: kernel_block_size 64 alignment
         if model_config.use_mla:
-            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
+            use_cutlass_mla = (
+                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
+            )
             kernel_block_alignment_size = 128 if use_cutlass_mla else 64
             attn_page_size_1_token = MLAAttentionSpec(
                 block_size=1,
@@ -361,8 +364,8 @@
             current_platform.is_device_capability(100)
             and model_config.get_head_size() == 256
             and (
-                envs.VLLM_ATTENTION_BACKEND is None
-                or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
+                attention_config.backend is None
+                or attention_config.backend == AttentionBackendEnum.FLASHINFER
             )
         ):
             # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that`

View File

@@ -11,7 +11,7 @@ import torch
 from transformers import PretrainedConfig
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -91,10 +91,7 @@ def get_vit_attn_backend(
     if attn_backend_override is not None:
         return attn_backend_override
-    # Lazy import to avoid circular dependency
-    from vllm.attention.selector import get_env_variable_attn_backend
-    selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
+    selected_backend = get_current_vllm_config().attention_config.backend
     if selected_backend is not None:
         return selected_backend

View File

@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C  # noqa
-import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
@@ -149,6 +148,8 @@ class CudaPlatformBase(Platform):
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm.attention.backends.registry import AttentionBackendEnum
         parallel_config = vllm_config.parallel_config
         model_config = vllm_config.model_config
@@ -171,7 +172,7 @@
             and cache_config.block_size is not None
         ):
             use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
-            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # If `--attention-config.backend` is not set and we are using MLA,
             # then we default to FlashMLA backend for non-blackwell GPUs,
             # else we default to CutlassMLA. For each case, we force the
             # required block_size.
@@ -179,23 +180,25 @@
             use_cutlass_mla = False
             use_flashinfer_mla = False
-            if envs.VLLM_ATTENTION_BACKEND is None:
+            if vllm_config.attention_config.backend is None:
                 # Default case
                 if cls.is_device_capability(100):
                     # Blackwell => Force CutlassMLA.
                     use_cutlass_mla = True
-                    # TODO: This does not work, because the
-                    # global_force_attn_backend_context_manager is not set.
-                    # See vllm/attention/selector.py:_cached_get_attn_backend
-                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
+                    # Set the backend in AttentionConfig so it's used during
+                    # backend selection
+                    vllm_config.attention_config.backend = (
+                        AttentionBackendEnum.CUTLASS_MLA
+                    )
                 else:
                     # Not Blackwell
                     use_flashmla = True
             else:
                 # Forced case
-                use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
-                use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
-                use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"
+                backend = vllm_config.attention_config.backend
+                use_flashmla = backend == AttentionBackendEnum.FLASHMLA
+                use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
+                use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
             from vllm.attention.ops.flashmla import is_flashmla_dense_supported

View File

@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
     return current_platform.is_device_capability(100) and has_nvidia_artifactory()
-@functools.cache
-def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
-    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
-    if env_value is not None:
-        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-    return env_value
 def force_use_trtllm_attention() -> bool | None:
     """
-    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    Return `None` if --attention-config.use_trtllm_attention is not set,
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
-    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    return vllm_config.attention_config.use_trtllm_attention
 def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
@@ -307,7 +302,7 @@ def use_trtllm_attention(
     """Return `True` if TRTLLM attention is used."""
     force_use_trtllm = force_use_trtllm_attention()
-    # Environment variable is set to 0 - respect it
+    # CLI argument is set to 0 - respect it
     if force_use_trtllm is not None and not force_use_trtllm:
         return False
@@ -324,7 +319,7 @@
         if force_use_trtllm:
             logger.warning_once(
                 "TRTLLM attention is not supported on this platform, "
-                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+                "but --attention-config.use_trtllm_attention is set to 1"
             )
         return False
@@ -333,7 +328,8 @@
         if force_use_trtllm:
             logger.warning_once(
                 "TRTLLM attention is not supported for this combination of "
-                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+                "query and key heads, but --attention-config.use_trtllm_attention is "
+                "set to 1"
             )
         return False
@@ -354,7 +350,7 @@
         return True
     if force_use_trtllm is None:
-        # Environment variable not set - use auto-detection
+        # CLI argument not set - use auto-detection
         if is_prefill:
             # Prefill auto-detection
             use_trtllm = kv_cache_dtype == "auto"
@@ -367,8 +363,10 @@
             logger.warning_once("Using TRTLLM decode attention (auto-detected).")
         return use_trtllm
-    # Environment variable is set to 1 - respect it
-    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
+    # CLI argument is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
+    )
     return True
@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
     return output
-@functools.cache
-def flashinfer_disable_q_quantization() -> bool:
-    """Cache result which only depends on the environment"""
-    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
@@ -526,7 +518,6 @@ __all__ = [
     "supports_trtllm_attention",
     "can_use_trtllm_attention",
     "use_trtllm_attention",
-    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]

View File

@@ -8,7 +8,6 @@ from typing import ClassVar
 import numpy as np
 import torch
-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
+        self.attention_config = vllm_config.attention_config
         self.num_heads_q = self.model_config.get_num_attention_heads(
             self.parallel_config
@@ -304,7 +304,9 @@
         # When using cuda graph, we need to set the upper bound of the
         # number of splits so that large enough intermediate buffers are
         # pre-allocated during capture.
-        self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+        self.max_num_splits = (
+            self.attention_config.flash_attn_max_num_splits_for_cuda_graph
+        )
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
                 "heads in the layer"
             )
-    def supports_quant_query_input(self) -> bool:
-        return True
+        self.supports_quant_query_input = True
     def forward(
         self,

View File

@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import tl, triton
 from vllm.utils.flashinfer import (
     can_use_trtllm_attention,
-    flashinfer_disable_q_quantization,
     use_trtllm_attention,
 )
 from vllm.utils.math_utils import cdiv
@@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
             supports_trtllm_attention,
         )
-        # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
+        # Respect explicit disable flag (e.g.,
+        # --attention-config.use_trtllm_attention=0)
         if force_use_trtllm_attention() is False:
             return False
@@ -500,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.kv_cache_dtype = self.kv_cache_spec.dtype
         # Use model dtype as q dtype when TRTLLM attn is not supported, or
-        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
-        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
+        # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
         # if TRTLLM attention kernel is not used when building attn metadata
         can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
-        if can_use_trtllm and not flashinfer_disable_q_quantization():
+        if (
+            can_use_trtllm
+            and not vllm_config.attention_config.disable_flashinfer_q_quantization
+        ):
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
@@ -1035,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
         self.sinks = sinks
         self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
+        vllm_config = get_current_vllm_config()
+        self.supports_quant_query_input = (
+            self.support_trtllm_attn
+            and not vllm_config.attention_config.disable_flashinfer_q_quantization
+        )
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
         self.o_sf_scale: float | None = None
@@ -1046,12 +1054,6 @@
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )
-    def supports_quant_query_input(self) -> bool:
-        if flashinfer_disable_q_quantization():
-            return False
-        return self.support_trtllm_attn
     # FlashInfer requires attention sinks to be float32
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         if self.sinks is not None and self.sinks.dtype != torch.float32:

View File

@@ -438,19 +438,25 @@ A = TypeVar("A")
 def use_flashinfer_prefill() -> bool:
     # For blackwell default to flashinfer prefill if it's available since
     # it is faster than FA2.
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
     return (
-        not envs.VLLM_DISABLE_FLASHINFER_PREFILL
+        not vllm_config.attention_config.disable_flashinfer_prefill
         and flashinfer_available
-        and not envs.VLLM_USE_CUDNN_PREFILL
-        and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and not vllm_config.attention_config.use_cudnn_prefill
+        and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
        and current_platform.is_device_capability(100)
     )
 def use_cudnn_prefill() -> bool:
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
     return (
         flashinfer_available
-        and envs.VLLM_USE_CUDNN_PREFILL
+        and vllm_config.attention_config.use_cudnn_prefill
         and current_platform.is_device_capability(100)
         and has_nvidia_artifactory()
     )
@@ -458,9 +464,12 @@ def use_cudnn_prefill() -> bool:
 def use_trtllm_ragged_deepseek_prefill() -> bool:
     """Check if TRT-LLM ragged DeepSeek prefill should be used."""
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
     return (
         flashinfer_available
-        and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
         and current_platform.is_device_capability(100)
     )

View File

@@ -6,7 +6,6 @@ from typing import ClassVar
 import torch
-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionLayer,
     AttentionType,
@@ -131,7 +130,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # When using cuda graph, we need to set the upper bound of the
         # number of splits so that large enough intermediate buffers are
         # pre-allocated during capture.
-        self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+        self.max_num_splits = (
+            vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
+        )
         if vllm_is_batch_invariant():
             self.max_num_splits = 1

View File

@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
             raise ValueError(
                 f"Head size {head_size} is not supported by {attn_type}. "
                 f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
-                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+                "Set --attention-config.backend=FLEX_ATTENTION to use "
                 "FlexAttention backend which supports all head sizes."
             )

View File

@@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
-    def supports_quant_query_input(self) -> bool:
-        return current_platform.is_cuda()
     def __init__(
         self,
         num_heads: int,
@@ -262,6 +259,8 @@
                 f"num_heads: {num_heads}."
             )
+        self.supports_quant_query_input = current_platform.is_cuda()
     def forward(
         self,
         layer: torch.nn.Module,