[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>

This commit is contained in:
parent dff0a2b394
commit 66e674cdd5
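In short, settings that used to be environment variables (VLLM_ATTENTION_BACKEND, VLLM_FLASH_ATTN_VERSION, VLLM_USE_TRTLLM_ATTENTION, and friends) now live on a new AttentionConfig object attached to VllmConfig and are exposed as CLI arguments. A minimal sketch of the new usage, based on the test and arg-parser changes in this diff; the JSON form of --attention-config is an assumption by analogy with --compilation-config:

```python
# Sketch of the new configuration path (names taken from this diff).
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig, VllmConfig

# Before: VLLM_ATTENTION_BACKEND=FLASHINFER (env var, now deprecated)
# After:  pass the backend through AttentionConfig...
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig(attention_config=attention_config)

# ...or on the command line via the new arguments added below:
#   vllm serve <model> --attention-backend FLASHINFER
#   vllm serve <model> --attention-config '{"backend": "FLASHINFER"}'  # assumed JSON form
```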
@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
CompilationMode,

@@ -335,6 +335,7 @@ def test_attention_quant_pattern(
custom_ops=custom_ops_list,
),
cache_config=CacheConfig(cache_dtype="fp8"),
attention_config=AttentionConfig(backend=backend),
)

# Create test inputs

@@ -352,7 +353,6 @@ def test_attention_quant_pattern(
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
global_force_attn_backend_context_manager(backend),
):
model_unfused = model_class(
num_qo_heads=num_qo_heads,

@@ -378,7 +378,6 @@ def test_attention_quant_pattern(
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
global_force_attn_backend_context_manager(backend),
):
model_fused = model_class(
num_qo_heads=num_qo_heads,

@@ -1151,13 +1151,29 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
}

# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1

if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4

with (
patch(

@@ -1192,7 +1208,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
assert len(caches_data) == 4
assert len(caches_data) == expected_num_entries

for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry

@@ -1214,7 +1230,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)

expected_block_len = expected_tensor_size // 2
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks

for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, (
@@ -6,8 +6,10 @@ import pytest
import torch

from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.config import (
AttentionConfig,
CacheConfig,
ModelConfig,
ParallelConfig,

@@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
def test_hybrid_attention_mamba_tensor_shapes():
"""
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers

@@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype="auto",
)
parallel_config = ParallelConfig()
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
attention_config=attention_config,
)

layer_0 = "model.layers.0.self_attn.attn"

@@ -820,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"

with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:

@@ -851,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)
# suppress var not used error
assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context

with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
vllm_ctx = vllm_config.compilation_config.static_forward_context

runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
@@ -865,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)[0]
runner.initialize_kv_cache(kv_cache_config)

# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks = kv_cache_config.num_blocks
ind = np.arange(num_blocks)
np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks = kv_cache_config.num_blocks
ind = np.arange(num_blocks)
np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]

attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape

# assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks
# assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks

# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))
# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))

# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2

# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]

# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]

# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]

attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
)
conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
)
ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
)
attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
)
conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
)
ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
)

# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio

for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]

# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]

# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i]
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i]

# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)
# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)

for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]

assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)

for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)


def test_hybrid_block_table_initialization():
@@ -289,6 +289,16 @@ class AttentionImpl(ABC, Generic[T]):
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False

# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False

dcp_world_size: int
dcp_rank: int

@@ -368,22 +378,6 @@ class AttentionImpl(ABC, Generic[T]):
"""
return False

def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.

When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584

Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False

def process_weights_after_loading(self, act_dtype: torch.dtype):
pass

@@ -303,7 +303,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input()
and self.impl.supports_quant_query_input
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

@@ -338,7 +338,7 @@ class Attention(nn.Module, AttentionLayerBase):
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

# check if query quantization is supported
if self.impl.supports_quant_query_input():
if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale)

if self.use_output:
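The hunks above replace the `supports_quant_query_input()` method on `AttentionImpl` with a plain attribute that implementations set in `__init__`, and the `Attention` layer now reads the attribute instead of calling a method. A minimal, hypothetical sketch of how a backend opts in under the new convention (it mirrors the FlashAttentionImpl and TritonAttentionImpl hunks later in this diff; the class name and simplified constructor are illustrative only):

```python
# Hypothetical backend implementation, heavily simplified; only the attribute
# handling is meant to be illustrative.
from vllm.attention.backends.abstract import AttentionImpl


class SketchAttentionImpl(AttentionImpl):
    def __init__(self, *args, **kwargs) -> None:
        # ... normal backend setup elided ...
        # Opt in to receiving fp8-quantized queries; the Attention layer
        # checks `self.impl.supports_quant_query_input` (no call) before
        # enabling QuantFP8 on the query path.
        self.supports_quant_query_input = True
```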
@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import inspect
import os
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast, get_args

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
AttentionBackendEnum,
MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType

@@ -24,60 +19,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)


def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
"""
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.

Returns:

* AttentionBackendEnum value if an override is specified
* None otherwise
"""
backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
if backend_name is None:
return None
if backend_name == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
"details). Please select a supported attention backend."
)
return AttentionBackendEnum[backend_name]


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: AttentionBackendEnum | None = None


def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
"""
Force all attention operations to use a specified backend.

Passing `None` for the argument re-enables automatic
backend selection.,

Arguments:

* attn_backend: backend selection (None to revert to auto)
"""
global forced_attn_backend
forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
"""
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
"""
return forced_attn_backend


def get_attn_backend(
head_size: int,
dtype: torch.dtype,

@@ -97,7 +38,13 @@ def get_attn_backend(
f"Valid values are: {valid_cache_dtypes}"
)

from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend

return _cached_get_attn_backend(
backend=backend_enum,
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),

@@ -111,6 +58,7 @@ def get_attn_backend(

@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,

@@ -120,39 +68,6 @@ def _cached_get_attn_backend(
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
if backend_by_env_var.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the environment variable "
"VLLM_ATTENTION_BACKEND is no longer necessary as "
"V0 backends have been deprecated. "
"Please remove this suffix from your "
"environment variable setting.",
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
try:
selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) from e

# get device-specific attn_backend
from vllm.platforms import current_platform

sig = inspect.signature(current_platform.get_attn_backend_cls)

@@ -163,7 +78,7 @@ def _cached_get_attn_backend(
"remove it from your plugin code."
)
attention_cls = current_platform.get_attn_backend_cls(
selected_backend,
backend,
head_size,
dtype,
kv_cache_dtype,

@@ -176,7 +91,7 @@ def _cached_get_attn_backend(
)
else:
attention_cls = current_platform.get_attn_backend_cls(
selected_backend,
backend,
head_size,
dtype,
kv_cache_dtype,

@@ -232,37 +147,3 @@ def _cached_get_mamba_attn_backend(

mamba_attn_backend = selected_backend.get_class()
return mamba_attn_backend


@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
"""
Globally force a vLLM attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.

Arguments:

* attn_backend: attention backend to force

Returns:

* Generator
"""

# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()

# Globally force the new backend override
global_force_attn_backend(attn_backend)

# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
_cached_get_attn_backend.cache_clear()
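With `get_env_variable_attn_backend`, `global_force_attn_backend`, and the force-backend context manager removed, `get_attn_backend` now reads the backend from the current `VllmConfig` via `get_current_vllm_config().attention_config.backend`. A hedged sketch of how callers (for example the tests updated earlier in this diff) force a backend now; it assumes `set_current_vllm_config` is the same helper those tests already import from `vllm.config`:

```python
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

vllm_config = VllmConfig(
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
)

with set_current_vllm_config(vllm_config):
    # Anything that resolves a backend inside this context (get_attn_backend,
    # platform checks, etc.) now sees FLASHINFER without any
    # VLLM_ATTENTION_BACKEND monkeypatching or global force helpers.
    ...
```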
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

@@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)

# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
fa_version = vllm_config.attention_config.flash_attn_version

# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config.attention import AttentionConfig
from vllm.config.cache import CacheConfig
from vllm.config.compilation import (
CompilationConfig,

@@ -46,6 +47,8 @@ from vllm.config.vllm import (
# __all__ should only contain classes and functions.
# Types and globals should be imported from their respective modules.
__all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache
"CacheConfig",
# From vllm.config.compilation
vllm/config/attention.py (new file, +114 lines)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Literal

from pydantic import field_validator
from pydantic.dataclasses import dataclass

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.logger import init_logger

logger = init_logger(__name__)


@config
@dataclass
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""

backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""

flash_attn_version: Literal[2, 3] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3).
Only valid when using the flash-attention backend."""

use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for attention instead of
the unified triton kernel."""

flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""

use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""

use_trtllm_ragged_deepseek_prefill: bool = False
"""Whether to use TRTLLM ragged deepseek prefill."""

use_trtllm_attention: bool | None = None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""

disable_flashinfer_prefill: bool = False
"""Whether to disable flashinfer prefill."""

disable_flashinfer_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8."""

def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
from vllm.config.utils import get_hash_factors, hash_factors

ignored_factors: list[str] = []
factors = get_hash_factors(self, ignored_factors)
return hash_factors(factors)

@field_validator("backend", mode="before")
@classmethod
def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string."""
if isinstance(value, str):
return AttentionBackendEnum[value.upper()]
return value

def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
"""Set field from env var if set, with deprecation warning."""
from vllm import envs

if envs.is_set(env_var_name):
value = getattr(envs, env_var_name)
if field_name == "backend":
value = self.validate_backend_before(value)
setattr(self, field_name, value)
logger.warning_once(
"Using %s environment variable is deprecated and will be removed in "
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
"--attention-config.%s command line argument or "
"AttentionConfig(%s=...) config field instead.",
env_var_name,
field_name,
field_name,
)

def __post_init__(self) -> None:
self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND")
self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION")
self._set_from_env_if_set(
"use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"
)
self._set_from_env_if_set(
"flash_attn_max_num_splits_for_cuda_graph",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
)
self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL")
self._set_from_env_if_set(
"use_trtllm_ragged_deepseek_prefill",
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
)
self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION")
self._set_from_env_if_set(
"disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"
)
self._set_from_env_if_set(
"disable_flashinfer_q_quantization",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
)
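`__post_init__` above keeps the old environment variables working for now: if one is set, its value is copied into the corresponding field and a deprecation warning is logged. A small, hedged sketch of that shim in action (the backend name FLASHINFER is taken from elsewhere in this diff; the exact semantics of `envs.is_set` at instantiation time are assumed):

```python
import os

# Old-style configuration still takes effect for now...
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig

config = AttentionConfig()
# ...but routes through _set_from_env_if_set, which also warns that the
# variable is deprecated in favour of --attention-config.backend /
# AttentionConfig(backend=...).
assert config.backend == AttentionBackendEnum.FLASHINFER
```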
@@ -4,7 +4,6 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Literal, cast, get_args

import torch

@@ -467,18 +466,6 @@ class ModelConfig:

self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)

if (
(backend := envs.VLLM_ATTENTION_BACKEND)
and backend == "FLASHINFER"
and find_spec("flashinfer") is None
):
raise ValueError(
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
"module was not found. See "
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it."
)

from vllm.platforms import current_platform

if self.override_attention_dtype is not None and not current_platform.is_rocm():

@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash

from .attention import AttentionConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig

@@ -192,6 +193,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: LoRAConfig | None = None
"""LoRA configuration."""
speculative_config: SpeculativeConfig | None = None

@@ -279,6 +282,10 @@ class VllmConfig:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
else:
@@ -34,6 +34,7 @@ from typing_extensions import TypeIs
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
ConfigType,

@@ -527,6 +528,7 @@ class EngineArgs:

pooler_config: PoolerConfig | None = ModelConfig.pooler_config
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls

@@ -542,6 +544,7 @@ class EngineArgs:
)
model_impl: str = ModelConfig.model_impl
override_attention_dtype: str = ModelConfig.override_attention_dtype
attention_backend: AttentionBackendEnum | None = AttentionConfig.backend

calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype

@@ -580,6 +583,8 @@ class EngineArgs:
# CompilationConfig object
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(**self.compilation_config)
if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.eplb_config, dict):
self.eplb_config = EPLBConfig(**self.eplb_config)
# Setup plugins

@@ -717,6 +722,16 @@ class EngineArgs:
"--pt-load-map-location", **load_kwargs["pt_load_map_location"]
)

# Attention arguments
attention_kwargs = get_kwargs(AttentionConfig)
attention_group = parser.add_argument_group(
title="AttentionConfig",
description=AttentionConfig.__doc__,
)
attention_group.add_argument(
"--attention-backend", **attention_kwargs["backend"]
)

# Structured outputs arguments
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
structured_outputs_group = parser.add_argument_group(

@@ -1140,6 +1155,9 @@ class EngineArgs:
vllm_group.add_argument(
"--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
)
vllm_group.add_argument(
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
)
vllm_group.add_argument(
"--additional-config", **vllm_kwargs["additional_config"]
)

@@ -1693,6 +1711,16 @@ class EngineArgs:
if model_config.quantization == "bitsandbytes":
self.quantization = self.load_format = "bitsandbytes"

# Attention config overrides
attention_config = copy.deepcopy(self.attention_config)
if self.attention_backend is not None:
if attention_config.backend is not None:
raise ValueError(
"attention_backend and attention_config.backend "
"are mutually exclusive"
)
attention_config.backend = self.attention_backend

load_config = self.create_load_config()

# Pass reasoning_parser into StructuredOutputsConfig

@@ -1750,9 +1778,10 @@ class EngineArgs:
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
load_config=load_config,
attention_config=attention_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
structured_outputs_config=self.structured_outputs_config,
observability_config=observability_config,
compilation_config=compilation_config,
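On the engine-args side, the parser gains a `--attention-backend` shorthand and a full `--attention-config` / `-ac` group, and the two ways of choosing a backend are mutually exclusive. A hedged sketch using the dataclass fields added above; the model name is only an example, the ValueError message is quoted from the hunk above, and the JSON CLI form is assumed by analogy with `--compilation-config`:

```python
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.engine.arg_utils import EngineArgs

# Shorthand, equivalent to `--attention-backend FLASHINFER` on the CLI:
args = EngineArgs(
    model="facebook/opt-125m",
    attention_backend=AttentionBackendEnum.FLASHINFER,
)

# Full config, equivalent to `--attention-config '{"backend": "FLASHINFER"}'`:
args = EngineArgs(
    model="facebook/opt-125m",
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER),
)

# Setting both at once raises, when the engine config is built:
#   ValueError("attention_backend and attention_config.backend are mutually exclusive")
```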
@@ -4,7 +4,7 @@ from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING

import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform

@@ -331,6 +331,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)

attention_config = vllm_config.attention_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config

@@ -347,7 +348,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
use_cutlass_mla = (
attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
)
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,

@@ -361,8 +364,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
current_platform.is_device_capability(100)
and model_config.get_head_size() == 256
and (
envs.VLLM_ATTENTION_BACKEND is None
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
attention_config.backend is None
or attention_config.backend == AttentionBackendEnum.FLASHINFER
)
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that`

@@ -11,7 +11,7 @@ import torch
from transformers import PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,

@@ -91,10 +91,7 @@ def get_vit_attn_backend(
if attn_backend_override is not None:
return attn_backend_override

# Lazy import to avoid circular dependency
from vllm.attention.selector import get_env_variable_attn_backend

selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
selected_backend = get_current_vllm_config().attention_config.backend
if selected_backend is not None:
return selected_backend


@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger

@@ -149,6 +148,8 @@ class CudaPlatformBase(Platform):

@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.attention.backends.registry import AttentionBackendEnum

parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config

@@ -171,7 +172,7 @@ class CudaPlatformBase(Platform):
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.

@@ -179,23 +180,25 @@ class CudaPlatformBase(Platform):
use_cutlass_mla = False
use_flashinfer_mla = False

if envs.VLLM_ATTENTION_BACKEND is None:
if vllm_config.attention_config.backend is None:
# Default case
if cls.is_device_capability(100):
# Blackwell => Force CutlassMLA.
use_cutlass_mla = True
# TODO: This does not work, because the
# global_force_attn_backend_context_manager is not set.
# See vllm/attention/selector.py:_cached_get_attn_backend
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.CUTLASS_MLA
)
else:
# Not Blackwell
use_flashmla = True
else:
# Forced case
use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA

from vllm.attention.ops.flashmla import is_flashmla_dense_supported
@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
return current_platform.is_device_capability(100) and has_nvidia_artifactory()


@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value


def force_use_trtllm_attention() -> bool | None:
"""
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
Return `None` if --attention-config.use_trtllm_attention is not set,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
return vllm_config.attention_config.use_trtllm_attention


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:

@@ -307,7 +302,7 @@ def use_trtllm_attention(
"""Return `True` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()

# Environment variable is set to 0 - respect it
# CLI argument is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False

@@ -324,7 +319,7 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"but --attention-config.use_trtllm_attention is set to 1"
)
return False

@@ -333,7 +328,8 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
)
return False

@@ -354,7 +350,7 @@ def use_trtllm_attention(
return True

if force_use_trtllm is None:
# Environment variable not set - use auto-detection
# CLI argument not set - use auto-detection
if is_prefill:
# Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto"

@@ -367,8 +363,10 @@ def use_trtllm_attention(
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm

# Environment variable is set to 1 - respect it
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
# CLI argument is set to 1 - respect it
logger.info_once(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return True


@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
return output


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION


__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",

@@ -526,7 +518,6 @@ __all__ = [
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]
@@ -8,7 +8,6 @@ from typing import ClassVar
import numpy as np
import torch

from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,

@@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.attention_config = vllm_config.attention_config

self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config

@@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
self.max_num_splits = (
self.attention_config.flash_attn_max_num_splits_for_cuda_graph
)

# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.

@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer"
)

def supports_quant_query_input(self) -> bool:
return True
self.supports_quant_query_input = True

def forward(
self,

@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger

@@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import (
can_use_trtllm_attention,
flashinfer_disable_q_quantization,
use_trtllm_attention,
)
from vllm.utils.math_utils import cdiv

@@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
supports_trtllm_attention,
)

# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
# Respect explicit disable flag (e.g.,
# --attention-config.use_trtllm_attention=0)
if force_use_trtllm_attention() is False:
return False

@@ -500,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_dtype = self.kv_cache_spec.dtype

# Use model dtype as q dtype when TRTLLM attn is not supported, or
# VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
# use fp8 q if kv cache is fp8, and will fall back to model dtype
# --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if can_use_trtllm and not flashinfer_disable_q_quantization():
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
):
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype

@@ -1035,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
self.sinks = sinks

self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
vllm_config = get_current_vllm_config()
self.supports_quant_query_input = (
self.support_trtllm_attn
and not vllm_config.attention_config.disable_flashinfer_q_quantization
)
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
self.o_sf_scale: float | None = None

@@ -1046,12 +1054,6 @@ class FlashInferImpl(AttentionImpl):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
)

def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False

return self.support_trtllm_attn

# FlashInfer requires attention sinks to be float32
def process_weights_after_loading(self, act_dtype: torch.dtype):
if self.sinks is not None and self.sinks.dtype != torch.float32:
@@ -438,19 +438,25 @@ A = TypeVar("A")
def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
return (
not envs.VLLM_DISABLE_FLASHINFER_PREFILL
not vllm_config.attention_config.disable_flashinfer_prefill
and flashinfer_available
and not envs.VLLM_USE_CUDNN_PREFILL
and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and not vllm_config.attention_config.use_cudnn_prefill
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability(100)
)


def use_cudnn_prefill() -> bool:
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
return (
flashinfer_available
and envs.VLLM_USE_CUDNN_PREFILL
and vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability(100)
and has_nvidia_artifactory()
)

@@ -458,9 +464,12 @@ def use_cudnn_prefill() -> bool:

def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
return (
flashinfer_available
and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability(100)
)


@@ -6,7 +6,6 @@ from typing import ClassVar

import torch

from vllm import envs
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,

@@ -131,7 +130,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
self.max_num_splits = (
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
)

if vllm_is_batch_invariant():
self.max_num_splits = 1

@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"Set --attention-config.backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)


@@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym

def supports_quant_query_input(self) -> bool:
return current_platform.is_cuda()

def __init__(
self,
num_heads: int,

@@ -262,6 +259,8 @@ class TritonAttentionImpl(AttentionImpl):
f"num_heads: {num_heads}."
)

self.supports_quant_query_input = current_platform.is_cuda()

def forward(
self,
layer: torch.nn.Module,