[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>

Parent: dff0a2b394
Commit: 66e674cdd5
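This change moves attention backend selection out of the VLLM_ATTENTION_BACKEND family of environment variables and into a new AttentionConfig section of VllmConfig, exposed on the command line as --attention-config (with --attention-backend as a shortcut for the backend field). As a rough, non-authoritative sketch of the new surface (the FLASHINFER choice below is only an illustration), mirroring how the updated tests wire the config through VllmConfig:

    # Illustrative sketch only: pick the attention backend via config
    # instead of the VLLM_ATTENTION_BACKEND environment variable.
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import AttentionConfig, VllmConfig

    vllm_config = VllmConfig(
        attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER),
    )

    # Equivalent CLI spellings registered by this commit:
    #   --attention-backend FLASHINFER
    #   --attention-config '{"backend": "FLASHINFER"}'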
@@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import Attention
-from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.matcher_utils import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     CompilationConfig,
     CompilationMode,
@@ -335,6 +335,7 @@ def test_attention_quant_pattern(
             custom_ops=custom_ops_list,
         ),
         cache_config=CacheConfig(cache_dtype="fp8"),
+        attention_config=AttentionConfig(backend=backend),
     )

     # Create test inputs
@@ -352,7 +353,6 @@ def test_attention_quant_pattern(
     with (
         set_current_vllm_config(vllm_config_unfused),
         set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
-        global_force_attn_backend_context_manager(backend),
     ):
         model_unfused = model_class(
             num_qo_heads=num_qo_heads,
@@ -378,7 +378,6 @@ def test_attention_quant_pattern(
     with (
         set_current_vllm_config(vllm_config),
         set_forward_context(attn_metadata=None, vllm_config=vllm_config),
-        global_force_attn_backend_context_manager(backend),
     ):
         model_fused = model_class(
             num_qo_heads=num_qo_heads,
@@ -1151,13 +1151,29 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     }

     # Store tensor info for validation
-    expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
-    expected_base_addrs = [
-        shared_tensor[0].data_ptr(),
-        shared_tensor[1].data_ptr(),
-        unique_tensor[0].data_ptr(),
-        unique_tensor[1].data_ptr(),
-    ]
+    test_shape = backend_cls.get_kv_cache_shape(
+        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+    )
+    is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
+
+    if is_blocks_first:
+        expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
+        expected_base_addrs = [
+            shared_tensor.data_ptr(),
+            unique_tensor.data_ptr(),
+        ]
+        expected_num_entries = 2
+    else:
+        expected_tensor_size = (
+            shared_tensor[0].element_size() * shared_tensor[0].numel()
+        )
+        expected_base_addrs = [
+            shared_tensor[0].data_ptr(),
+            shared_tensor[1].data_ptr(),
+            unique_tensor[0].data_ptr(),
+            unique_tensor[1].data_ptr(),
+        ]
+        expected_num_entries = 4

     with (
         patch(
@@ -1192,7 +1208,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     # Verify get_reg_descs was called with caches_data
     assert mock_wrapper_instance.get_reg_descs.called
     caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
-    assert len(caches_data) == 4
+    assert len(caches_data) == expected_num_entries

     for i, cache_entry in enumerate(caches_data):
         base_addr, size, _tp_rank, _ = cache_entry
@@ -1214,7 +1230,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
         f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
     )

-    expected_block_len = expected_tensor_size // 2
+    num_blocks = 2
+    if is_blocks_first:
+        expected_block_len = expected_tensor_size // num_blocks // 2
+    else:
+        expected_block_len = expected_tensor_size // num_blocks
+
     for i, block_entry in enumerate(blocks_data):
         block_start_addr, block_len, tp_rank = block_entry
         assert block_len == expected_block_len, (
@@ -6,8 +6,10 @@ import pytest
 import torch

 from vllm.attention.backends.abstract import MultipleOf
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import Attention
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     ModelConfig,
     ParallelConfig,
@@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
     current_platform.is_rocm(),
     reason="Attention backend FLASHINFER is not supported on ROCm.",
 )
-def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
+def test_hybrid_attention_mamba_tensor_shapes():
     """
     The GPU model runner creates different views into the
     KVCacheTensors for the attention and mamba layers
@@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         cache_dtype="auto",
     )
     parallel_config = ParallelConfig()
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
     vllm_config = VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
+        attention_config=attention_config,
     )

     layer_0 = "model.layers.0.self_attn.attn"
@@ -820,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"

-    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+    with set_current_vllm_config(vllm_config):
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:
@@ -851,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         )
         # suppress var not used error
         assert fwd_context is not None
         vllm_ctx = vllm_config.compilation_config.static_forward_context

-        with monkeypatch.context() as m:
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-
         runner = GPUModelRunner(vllm_config, DEVICE)
         kv_cache_spec = runner.get_kv_cache_spec()
@@ -865,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         )[0]
         runner.initialize_kv_cache(kv_cache_config)

         # random partition of blocks
         # blocks0 will be assigned to attention layers
         # blocks1 will be assigned to mamba layers
         num_blocks = kv_cache_config.num_blocks
         ind = np.arange(num_blocks)
         np.random.shuffle(ind)
         blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]

         attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
         conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
         ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape

         # assert we are using FlashInfer
         assert attn_shape[0] % num_blocks == 0
         block_split_ratio = attn_shape[0] // num_blocks

         # use small blocks for testing to avoid memory issues
         test_block_size = min(2, len(blocks0), len(blocks1))

         # use non-overlapping blocks to avoid data contamination
         # Split kernel blocks: first half for attention, second half for mamba
         mid_point = num_blocks // 2

         # attention uses kernel blocks from first half (mapped to logical blocks)
         kv_blocks_for_attention = np.array([0, 1])[:test_block_size]

         # mamba uses kernel blocks from second half
         kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]

         # create small constant tensors for testing with corrected shapes
         # attention: [block_size, ...] starting from dimension 2
         attn_constant_shape = attn_shape[2:]
         conv_constant_shape = conv_shape[1:]
         ssm_constant_shape = ssm_shape[1:]

         attn_blocks_constant = torch.full(
             (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
         )
         conv_blocks_constant = torch.full(
             (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
         )
         ssm_blocks_constant = torch.full(
             (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
         )

         # Fill attention blocks with constants using kv block indices
         kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio

         for layer in [layer_0, layer_1]:
             # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
             for i, kernel_block in enumerate(kernel_blocks_for_attention):
                 vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]

         # fill mamba blocks with constants using kernel block indices
         for layer in [layer_2, layer_3, layer_4, layer_5]:
             # mamba: kv_cache[0][component][kernel_block_idx, ...]
             for i, kv_block in enumerate(kv_blocks_for_mamba):
                 vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
                 vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]

         # verify attention and mamba contents are correct
         for layer in [layer_0, layer_1]:
             for i, kernel_block in enumerate(kernel_blocks_for_attention):
                 actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
                 expected = attn_blocks_constant[i]

                 # Check K and V separately
                 assert torch.equal(actual_kv[0], expected)
                 assert torch.equal(actual_kv[1], expected)

         for layer in [layer_2, layer_3, layer_4, layer_5]:
             for i, kv_block in enumerate(kv_blocks_for_mamba):
                 actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
                 actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
                 expected_conv = conv_blocks_constant[i]
                 expected_ssm = ssm_blocks_constant[i]

                 assert torch.equal(actual_conv, expected_conv)
                 assert torch.equal(actual_ssm, expected_ssm)

         for layer in [layer_2, layer_3, layer_4, layer_5]:
             for i, kv_block in enumerate(kv_blocks_for_mamba):
                 actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
                 actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
                 expected_conv = conv_blocks_constant[i]
                 expected_ssm = ssm_blocks_constant[i]
                 assert torch.equal(actual_conv, expected_conv)
                 assert torch.equal(actual_ssm, expected_ssm)


 def test_hybrid_block_table_initialization():
@@ -289,6 +289,16 @@ class AttentionImpl(ABC, Generic[T]):
     # even if they can return lse (for efficiency reasons)
     need_to_return_lse_for_decode: bool = False

+    # Whether this attention implementation supports pre-quantized query input.
+    # When True, the attention layer will quantize queries before passing them
+    # to this backend, allowing torch.compile to fuse the quantization with
+    # previous operations. This is typically supported when using FP8 KV cache
+    # with compatible attention kernels (e.g., TRT-LLM).
+    # Subclasses should set this in __init__.
+    # TODO add support to more backends:
+    # https://github.com/vllm-project/vllm/issues/25584
+    supports_quant_query_input: bool = False
+
     dcp_world_size: int
     dcp_rank: int

@@ -368,22 +378,6 @@ class AttentionImpl(ABC, Generic[T]):
         """
         return False

-    def supports_quant_query_input(self) -> bool:
-        """
-        Check if this attention implementation supports pre-quantized query input.
-
-        When True, the attention layer will quantize queries before passing them
-        to this backend, allowing torch.compile to fuse the quantization with
-        previous operations. This is typically supported when using FP8 KV cache
-        with compatible attention kernels (e.g., TRT-LLM).
-        TODO add support to more backends:
-        https://github.com/vllm-project/vllm/issues/25584
-
-        Returns:
-            bool: True if the implementation can accept pre-quantized queries.
-        """
-        return False
-
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         pass

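supports_quant_query_input is now a plain attribute on AttentionImpl rather than a method, so a backend opts in by assigning it during construction and the attention layer reads it without a call. A minimal sketch under that assumption, using a hypothetical FP8-capable implementation (not a class from this commit):

    # Hypothetical backend implementation, for illustration only.
    class MyFP8AttentionImpl(AttentionImpl):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Allow the Attention layer to pre-quantize Q so torch.compile
            # can fuse the quantization with preceding ops.
            self.supports_quant_query_input = True

    # Layer-side check (see the vllm/attention/layer.py hunks below):
    #   before: if self.impl.supports_quant_query_input(): ...
    #   after:  if self.impl.supports_quant_query_input: ...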
@@ -303,7 +303,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.query_quant = None
         if (
             self.kv_cache_dtype.startswith("fp8")
-            and self.impl.supports_quant_query_input()
+            and self.impl.supports_quant_query_input
         ):
             self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

@@ -338,7 +338,7 @@ class Attention(nn.Module, AttentionLayerBase):
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

             # check if query quantization is supported
-            if self.impl.supports_quant_query_input():
+            if self.impl.supports_quant_query_input:
                 query, _ = self.query_quant(query, self._q_scale)

         if self.use_output:
@@ -2,19 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import inspect
-import os
-from collections.abc import Generator
-from contextlib import contextmanager
 from functools import cache
 from typing import cast, get_args

 import torch

-import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.registry import (
     MAMBA_TYPE_TO_BACKEND_MAP,
-    AttentionBackendEnum,
     MambaAttentionBackendEnum,
 )
 from vllm.config.cache import CacheDType
@@ -24,60 +19,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
 logger = init_logger(__name__)


-def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
-    """
-    Get the backend override specified by the vLLM attention
-    backend environment variable, if one is specified.
-
-    Returns:
-
-    * AttentionBackendEnum value if an override is specified
-    * None otherwise
-    """
-    backend_name = os.environ.get("VLLM_ATTENTION_BACKEND")
-    if backend_name is None:
-        return None
-    if backend_name == "XFORMERS":
-        raise ValueError(
-            "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
-            "details). Please select a supported attention backend."
-        )
-    return AttentionBackendEnum[backend_name]
-
-
-# Global state allows a particular choice of backend
-# to be forced, overriding the logic which auto-selects
-# a backend based on system & workload configuration
-# (default behavior if this variable is None)
-#
-# THIS SELECTION TAKES PRECEDENCE OVER THE
-# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
-forced_attn_backend: AttentionBackendEnum | None = None
-
-
-def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
-    """
-    Force all attention operations to use a specified backend.
-
-    Passing `None` for the argument re-enables automatic
-    backend selection.,
-
-    Arguments:
-
-    * attn_backend: backend selection (None to revert to auto)
-    """
-    global forced_attn_backend
-    forced_attn_backend = attn_backend
-
-
-def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
-    """
-    Get the currently-forced choice of attention backend,
-    or None if auto-selection is currently enabled.
-    """
-    return forced_attn_backend
-
-
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -97,7 +38,13 @@ def get_attn_backend(
             f"Valid values are: {valid_cache_dtypes}"
         )

+    from vllm.config import get_current_vllm_config
+
+    vllm_config = get_current_vllm_config()
+    backend_enum = vllm_config.attention_config.backend
+
     return _cached_get_attn_backend(
+        backend=backend_enum,
         head_size=head_size,
         dtype=dtype,
         kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@@ -111,6 +58,7 @@ def get_attn_backend(

 @cache
 def _cached_get_attn_backend(
+    backend,
     head_size: int,
     dtype: torch.dtype,
     kv_cache_dtype: CacheDType | None,
@@ -120,39 +68,6 @@ def _cached_get_attn_backend(
     use_sparse: bool = False,
     attn_type: str | None = None,
 ) -> type[AttentionBackend]:
-    # Check whether a particular choice of backend was
-    # previously forced.
-    #
-    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
-    # ENVIRONMENT VARIABLE.
-    selected_backend = None
-    backend_by_global_setting: AttentionBackendEnum | None = (
-        get_global_forced_attn_backend()
-    )
-    if backend_by_global_setting is not None:
-        selected_backend = backend_by_global_setting
-    else:
-        # Check the environment variable and override if specified
-        backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            if backend_by_env_var.endswith("_VLLM_V1"):
-                logger.warning(
-                    "The suffix '_VLLM_V1' in the environment variable "
-                    "VLLM_ATTENTION_BACKEND is no longer necessary as "
-                    "V0 backends have been deprecated. "
-                    "Please remove this suffix from your "
-                    "environment variable setting.",
-                )
-                backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
-            try:
-                selected_backend = AttentionBackendEnum[backend_by_env_var]
-            except KeyError as e:
-                raise ValueError(
-                    f"Invalid attention backend: '{backend_by_env_var}'. Valid "
-                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
-                ) from e
-
-    # get device-specific attn_backend
     from vllm.platforms import current_platform

     sig = inspect.signature(current_platform.get_attn_backend_cls)
@@ -163,7 +78,7 @@ def _cached_get_attn_backend(
             "remove it from your plugin code."
         )
         attention_cls = current_platform.get_attn_backend_cls(
-            selected_backend,
+            backend,
             head_size,
             dtype,
             kv_cache_dtype,
@@ -176,7 +91,7 @@ def _cached_get_attn_backend(
         )
     else:
         attention_cls = current_platform.get_attn_backend_cls(
-            selected_backend,
+            backend,
             head_size,
             dtype,
             kv_cache_dtype,
@@ -232,37 +147,3 @@ def _cached_get_mamba_attn_backend(

     mamba_attn_backend = selected_backend.get_class()
     return mamba_attn_backend
-
-
-@contextmanager
-def global_force_attn_backend_context_manager(
-    attn_backend: AttentionBackendEnum,
-) -> Generator[None, None, None]:
-    """
-    Globally force a vLLM attention backend override within a
-    context manager, reverting the global attention backend
-    override to its prior state upon exiting the context
-    manager.
-
-    Arguments:
-
-    * attn_backend: attention backend to force
-
-    Returns:
-
-    * Generator
-    """
-
-    # Save the current state of the global backend override (if any)
-    original_value = get_global_forced_attn_backend()
-
-    # Globally force the new backend override
-    global_force_attn_backend(attn_backend)
-
-    # Yield control back to the enclosed code block
-    try:
-        yield
-    finally:
-        # Revert the original global backend override, if any
-        global_force_attn_backend(original_value)
-        _cached_get_attn_backend.cache_clear()
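With the environment-variable lookup and the global force helpers removed from the selector, get_attn_backend now resolves the backend from the AttentionConfig attached to the current VllmConfig, and callers that previously relied on global_force_attn_backend_context_manager pin a backend through the config context instead. A minimal sketch of that pattern, assuming it is entered the same way the updated tests do:

    # Sketch: force a backend for a block of code via the config context,
    # replacing the removed global_force_attn_backend_context_manager.
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

    vllm_config = VllmConfig(
        attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASHINFER),
    )
    with set_current_vllm_config(vllm_config):
        # Inside this context, get_attn_backend() reads
        # vllm_config.attention_config.backend instead of VLLM_ATTENTION_BACKEND.
        ...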
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from vllm import envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

@@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
             3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
         )

-        # 2. override if passed by environment
-        if envs.VLLM_FLASH_ATTN_VERSION is not None:
-            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
-            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+        # 2. override if passed by environment or config
+        from vllm.config import get_current_vllm_config
+
+        vllm_config = get_current_vllm_config()
+        if vllm_config.attention_config.flash_attn_version is not None:
+            fa_version = vllm_config.attention_config.flash_attn_version

         # 3. fallback for unsupported combinations
         if device_capability.major == 10 and fa_version == 3:
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from vllm.config.attention import AttentionConfig
 from vllm.config.cache import CacheConfig
 from vllm.config.compilation import (
     CompilationConfig,
@@ -46,6 +47,8 @@ from vllm.config.vllm import (
 # __all__ should only contain classes and functions.
 # Types and globals should be imported from their respective modules.
 __all__ = [
+    # From vllm.config.attention
+    "AttentionConfig",
     # From vllm.config.cache
     "CacheConfig",
     # From vllm.config.compilation
new file: vllm/config/attention.py (+114 lines)
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any, Literal
+
+from pydantic import field_validator
+from pydantic.dataclasses import dataclass
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.config.utils import config
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@config
+@dataclass
+class AttentionConfig:
+    """Configuration for attention mechanisms in vLLM."""
+
+    backend: AttentionBackendEnum | None = None
+    """Attention backend to use. If None, will be selected automatically."""
+
+    flash_attn_version: Literal[2, 3] | None = None
+    """Force vllm to use a specific flash-attention version (2 or 3).
+    Only valid when using the flash-attention backend."""
+
+    use_prefill_decode_attention: bool = False
+    """Use separate prefill and decode kernels for attention instead of
+    the unified triton kernel."""
+
+    flash_attn_max_num_splits_for_cuda_graph: int = 32
+    """Flash Attention max number splits for cuda graph decode."""
+
+    use_cudnn_prefill: bool = False
+    """Whether to use cudnn prefill."""
+
+    use_trtllm_ragged_deepseek_prefill: bool = False
+    """Whether to use TRTLLM ragged deepseek prefill."""
+
+    use_trtllm_attention: bool | None = None
+    """If set to True/False, use or don't use the TRTLLM attention backend
+    in flashinfer. If None, auto-detect the attention backend in flashinfer."""
+
+    disable_flashinfer_prefill: bool = False
+    """Whether to disable flashinfer prefill."""
+
+    disable_flashinfer_q_quantization: bool = False
+    """If set, when using fp8 kv, do not quantize Q to fp8."""
+
+    def compute_hash(self) -> str:
+        """
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        from vllm.config.utils import get_hash_factors, hash_factors
+
+        ignored_factors: list[str] = []
+        factors = get_hash_factors(self, ignored_factors)
+        return hash_factors(factors)
+
+    @field_validator("backend", mode="before")
+    @classmethod
+    def validate_backend_before(cls, value: Any) -> Any:
+        """Enable parsing of the `backend` enum type from string."""
+        if isinstance(value, str):
+            return AttentionBackendEnum[value.upper()]
+        return value
+
+    def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
+        """Set field from env var if set, with deprecation warning."""
+        from vllm import envs
+
+        if envs.is_set(env_var_name):
+            value = getattr(envs, env_var_name)
+            if field_name == "backend":
+                value = self.validate_backend_before(value)
+            setattr(self, field_name, value)
+            logger.warning_once(
+                "Using %s environment variable is deprecated and will be removed in "
+                "v0.14.0 or v1.0.0, whichever is soonest. Please use "
+                "--attention-config.%s command line argument or "
+                "AttentionConfig(%s=...) config field instead.",
+                env_var_name,
+                field_name,
+                field_name,
+            )
+
+    def __post_init__(self) -> None:
+        self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND")
+        self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION")
+        self._set_from_env_if_set(
+            "use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"
+        )
+        self._set_from_env_if_set(
+            "flash_attn_max_num_splits_for_cuda_graph",
+            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+        )
+        self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL")
+        self._set_from_env_if_set(
+            "use_trtllm_ragged_deepseek_prefill",
+            "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
+        )
+        self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION")
+        self._set_from_env_if_set(
+            "disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"
+        )
+        self._set_from_env_if_set(
+            "disable_flashinfer_q_quantization",
+            "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
+        )
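The __post_init__ hook above keeps the legacy environment variables working during the deprecation window: when one is set, its value is copied into the matching field via _set_from_env_if_set and a one-time warning points at the new --attention-config spelling. A hedged sketch of the expected behaviour (the asserts express the intent, they are not a test from this commit):

    # Sketch: the legacy env var still takes effect, but now through
    # AttentionConfig and with a deprecation warning.
    import os

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import AttentionConfig

    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # deprecated spelling
    legacy = AttentionConfig()                   # __post_init__ copies the env value in
    explicit = AttentionConfig(backend="FLASHINFER")  # validator parses the string

    assert legacy.backend is AttentionBackendEnum.FLASHINFER
    assert explicit.backend is AttentionBackendEnum.FLASHINFER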
@@ -4,7 +4,6 @@
 import warnings
 from collections.abc import Callable
 from dataclasses import InitVar, field
-from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Literal, cast, get_args

 import torch
@@ -467,18 +466,6 @@ class ModelConfig:

         self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)

-        if (
-            (backend := envs.VLLM_ATTENTION_BACKEND)
-            and backend == "FLASHINFER"
-            and find_spec("flashinfer") is None
-        ):
-            raise ValueError(
-                "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
-                "module was not found. See "
-                "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile "  # noqa: E501
-                "for instructions on how to install it."
-            )
-
         from vllm.platforms import current_platform

         if self.override_attention_dtype is not None and not current_platform.is_rocm():
@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
 from vllm.utils.hashing import safe_hash

+from .attention import AttentionConfig
 from .cache import CacheConfig
 from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
 from .device import DeviceConfig
@@ -192,6 +193,8 @@ class VllmConfig:
     """Device configuration."""
     load_config: LoadConfig = Field(default_factory=LoadConfig)
     """Load configuration."""
+    attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
+    """Attention configuration."""
     lora_config: LoRAConfig | None = None
     """LoRA configuration."""
     speculative_config: SpeculativeConfig | None = None
@@ -279,6 +282,10 @@ class VllmConfig:
             vllm_factors.append(self.load_config.compute_hash())
         else:
             vllm_factors.append("None")
+        if self.attention_config:
+            vllm_factors.append(self.attention_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
         else:
@@ -34,6 +34,7 @@ from typing_extensions import TypeIs
 import vllm.envs as envs
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     CompilationConfig,
     ConfigType,
@@ -527,6 +528,7 @@ class EngineArgs:

     pooler_config: PoolerConfig | None = ModelConfig.pooler_config
     compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
+    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls

@@ -542,6 +544,7 @@ class EngineArgs:
     )
     model_impl: str = ModelConfig.model_impl
     override_attention_dtype: str = ModelConfig.override_attention_dtype
+    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
@@ -580,6 +583,8 @@ class EngineArgs:
         # CompilationConfig object
         if isinstance(self.compilation_config, dict):
             self.compilation_config = CompilationConfig(**self.compilation_config)
+        if isinstance(self.attention_config, dict):
+            self.attention_config = AttentionConfig(**self.attention_config)
         if isinstance(self.eplb_config, dict):
             self.eplb_config = EPLBConfig(**self.eplb_config)
         # Setup plugins
@@ -717,6 +722,16 @@ class EngineArgs:
             "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
         )

+        # Attention arguments
+        attention_kwargs = get_kwargs(AttentionConfig)
+        attention_group = parser.add_argument_group(
+            title="AttentionConfig",
+            description=AttentionConfig.__doc__,
+        )
+        attention_group.add_argument(
+            "--attention-backend", **attention_kwargs["backend"]
+        )
+
         # Structured outputs arguments
         structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
         structured_outputs_group = parser.add_argument_group(
@@ -1140,6 +1155,9 @@ class EngineArgs:
         vllm_group.add_argument(
             "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
         )
+        vllm_group.add_argument(
+            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
+        )
         vllm_group.add_argument(
             "--additional-config", **vllm_kwargs["additional_config"]
         )
@@ -1693,6 +1711,16 @@ class EngineArgs:
         if model_config.quantization == "bitsandbytes":
             self.quantization = self.load_format = "bitsandbytes"

+        # Attention config overrides
+        attention_config = copy.deepcopy(self.attention_config)
+        if self.attention_backend is not None:
+            if attention_config.backend is not None:
+                raise ValueError(
+                    "attention_backend and attention_config.backend "
+                    "are mutually exclusive"
+                )
+            attention_config.backend = self.attention_backend
+
         load_config = self.create_load_config()

         # Pass reasoning_parser into StructuredOutputsConfig
@@ -1750,9 +1778,10 @@ class EngineArgs:
             parallel_config=parallel_config,
             scheduler_config=scheduler_config,
             device_config=device_config,
+            load_config=load_config,
+            attention_config=attention_config,
             lora_config=lora_config,
             speculative_config=speculative_config,
-            load_config=load_config,
             structured_outputs_config=self.structured_outputs_config,
             observability_config=observability_config,
             compilation_config=compilation_config,
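On the CLI, the new argument group exposes the whole dataclass as --attention-config (shorthand -ac) and just the backend field as --attention-backend; create_engine_config treats the two backend spellings as mutually exclusive and raises a ValueError if both are given. A hedged sketch of the two routes (model name and field values are placeholders, not from this commit):

    # Command-line forms (illustrative):
    #   vllm serve <model> --attention-backend FLASHINFER
    #   vllm serve <model> --attention-config '{"backend": "FLASHINFER", "use_cudnn_prefill": true}'

    # Programmatic form through EngineArgs:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="<model>",
        attention_backend=AttentionBackendEnum.FLASHINFER,
    )
    vllm_config = engine_args.create_engine_config()  # folded into AttentionConfig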
@@ -4,7 +4,7 @@ from copy import deepcopy
 from math import lcm
 from typing import TYPE_CHECKING

-import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
@@ -331,6 +331,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         # Enable FULL_AND_PIECEWISE by default
         MambaModelConfig.verify_and_update_config(vllm_config)

+        attention_config = vllm_config.attention_config
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -347,7 +348,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         # * CUTLASS_MLA backend: kernel_block_size 128 alignment
         # * Other MLA backends: kernel_block_size 64 alignment
         if model_config.use_mla:
-            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
+            use_cutlass_mla = (
+                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
+            )
             kernel_block_alignment_size = 128 if use_cutlass_mla else 64
             attn_page_size_1_token = MLAAttentionSpec(
                 block_size=1,
@@ -361,8 +364,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             current_platform.is_device_capability(100)
             and model_config.get_head_size() == 256
             and (
-                envs.VLLM_ATTENTION_BACKEND is None
-                or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
+                attention_config.backend is None
+                or attention_config.backend == AttentionBackendEnum.FLASHINFER
            )
         ):
             # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that`
@@ -11,7 +11,7 @@ import torch
 from transformers import PretrainedConfig

 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -91,10 +91,7 @@ def get_vit_attn_backend(
     if attn_backend_override is not None:
         return attn_backend_override

-    # Lazy import to avoid circular dependency
-    from vllm.attention.selector import get_env_variable_attn_backend
-
-    selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
+    selected_backend = get_current_vllm_config().attention_config.backend
     if selected_backend is not None:
         return selected_backend

@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec

 # import custom ops, trigger op registration
 import vllm._C  # noqa
-import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
@@ -149,6 +148,8 @@ class CudaPlatformBase(Platform):

     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+
         parallel_config = vllm_config.parallel_config
         model_config = vllm_config.model_config

@@ -171,7 +172,7 @@ class CudaPlatformBase(Platform):
             and cache_config.block_size is not None
         ):
             use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
-            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # If `--attention-config.backend` is not set and we are using MLA,
             # then we default to FlashMLA backend for non-blackwell GPUs,
             # else we default to CutlassMLA. For each case, we force the
             # required block_size.
@@ -179,23 +180,25 @@ class CudaPlatformBase(Platform):
             use_cutlass_mla = False
             use_flashinfer_mla = False

-            if envs.VLLM_ATTENTION_BACKEND is None:
+            if vllm_config.attention_config.backend is None:
                 # Default case
                 if cls.is_device_capability(100):
                     # Blackwell => Force CutlassMLA.
                     use_cutlass_mla = True
-                    # TODO: This does not work, because the
-                    # global_force_attn_backend_context_manager is not set.
-                    # See vllm/attention/selector.py:_cached_get_attn_backend
-                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
+                    # Set the backend in AttentionConfig so it's used during
+                    # backend selection
+                    vllm_config.attention_config.backend = (
+                        AttentionBackendEnum.CUTLASS_MLA
+                    )
                 else:
                     # Not Blackwell
                     use_flashmla = True
             else:
                 # Forced case
-                use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA"
-                use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
-                use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"
+                backend = vllm_config.attention_config.backend
+                use_flashmla = backend == AttentionBackendEnum.FLASHMLA
+                use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
+                use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA

             from vllm.attention.ops.flashmla import is_flashmla_dense_supported

@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
|
|||||||
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
|
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
|
|
||||||
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
|
|
||||||
if env_value is not None:
|
|
||||||
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
|
|
||||||
return env_value
|
|
||||||
|
|
||||||
|
|
||||||
def force_use_trtllm_attention() -> bool | None:
|
def force_use_trtllm_attention() -> bool | None:
|
||||||
"""
|
"""
|
||||||
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
|
Return `None` if --attention-config.use_trtllm_attention is not set,
|
||||||
return `True` if TRTLLM attention is forced to be used,
|
return `True` if TRTLLM attention is forced to be used,
|
||||||
return `False` if TRTLLM attention is forced to be not used.
|
return `False` if TRTLLM attention is forced to be not used.
|
||||||
"""
|
"""
|
||||||
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
return vllm_config.attention_config.use_trtllm_attention
|
||||||
|
|
||||||
|
|
||||||
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
|
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
|
||||||
@ -307,7 +302,7 @@ def use_trtllm_attention(
     """Return `True` if TRTLLM attention is used."""
     force_use_trtllm = force_use_trtllm_attention()

-    # Environment variable is set to 0 - respect it
+    # CLI argument is set to 0 - respect it
     if force_use_trtllm is not None and not force_use_trtllm:
         return False

@ -324,7 +319,7 @@ def use_trtllm_attention(
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
-               "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+               "but --attention-config.use_trtllm_attention is set to 1"
            )
        return False

@ -333,7 +328,8 @@ def use_trtllm_attention(
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
-               "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+               "query and key heads, but --attention-config.use_trtllm_attention is "
+               "set to 1"
            )
        return False

@ -354,7 +350,7 @@ def use_trtllm_attention(
        return True

    if force_use_trtllm is None:
-       # Environment variable not set - use auto-detection
+       # CLI argument not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = kv_cache_dtype == "auto"

@ -367,8 +363,10 @@ def use_trtllm_attention(
            logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

-   # Environment variable is set to 1 - respect it
-   logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
+   # CLI argument is set to 1 - respect it
+   logger.info_once(
+       "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
+   )
    return True


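The hunks above preserve the original precedence; only the knob moved from an environment variable to the attention config. A condensed, illustrative restatement of that precedence (parameter names are invented):

def decide_trtllm(force, platform_ok: bool, heads_ok: bool, auto_detected: bool) -> bool:
    # Condensed restatement of use_trtllm_attention(); parameter names are
    # invented, the ordering matches the hunks above. `force` is the
    # tri-state config value (None / True / False).
    if force is False:
        return False            # explicitly disabled in the attention config
    if not platform_ok or not heads_ok:
        return False            # unsupported; the real code warns if forced on
    if force is None:
        return auto_detected    # no preference given: rely on the heuristics
    return True                 # explicitly enabled in the attention config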
@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
     return output


-@functools.cache
-def flashinfer_disable_q_quantization() -> bool:
-    """Cache result which only depends on the environment"""
-    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
-
-
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
@ -526,7 +518,6 @@ __all__ = [
     "supports_trtllm_attention",
     "can_use_trtllm_attention",
     "use_trtllm_attention",
-    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]
@ -8,7 +8,6 @@ from typing import ClassVar
 import numpy as np
 import torch

-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
+       self.attention_config = vllm_config.attention_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
        # When using cuda graph, we need to set the upper bound of the
        # number of splits so that large enough intermediate buffers are
        # pre-allocated during capture.
-       self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+       self.max_num_splits = (
+           self.attention_config.flash_attn_max_num_splits_for_cuda_graph
+       )

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
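The CUDA-graph split bound now comes from AttentionConfig rather than VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH. Assuming the field is settable through the AttentionConfig constructor (an assumption; only the attribute read above appears in this diff), an override would look roughly like:

from vllm.config import AttentionConfig

# Assumed constructor keyword; only the attribute read is shown in the diff.
cfg = AttentionConfig(flash_attn_max_num_splits_for_cuda_graph=16)
print(cfg.flash_attn_max_num_splits_for_cuda_graph)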
@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
                "heads in the layer"
            )

-   def supports_quant_query_input(self) -> bool:
-       return True
+       self.supports_quant_query_input = True

    def forward(
        self,
@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import tl, triton
 from vllm.utils.flashinfer import (
     can_use_trtllm_attention,
-    flashinfer_disable_q_quantization,
     use_trtllm_attention,
 )
 from vllm.utils.math_utils import cdiv
@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
        supports_trtllm_attention,
    )

-   # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
+   # Respect explicit disable flag (e.g.,
+   # --attention-config.use_trtllm_attention=0)
    if force_use_trtllm_attention() is False:
        return False

@ -500,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        self.kv_cache_dtype = self.kv_cache_spec.dtype

        # Use model dtype as q dtype when TRTLLM attn is not supported, or
-       # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
-       # use fp8 q if kv cache is fp8, and will fall back to model dtype
+       # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
+       # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
-       if can_use_trtllm and not flashinfer_disable_q_quantization():
+       if (
+           can_use_trtllm
+           and not vllm_config.attention_config.disable_flashinfer_q_quantization
+       ):
            self.q_data_type = self.kv_cache_dtype
        else:
            self.q_data_type = self.model_config.dtype
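In isolation, the query-dtype rule above reduces to the following sketch (illustrative helper, not the builder code):

import torch


def pick_q_dtype(can_use_trtllm: bool, disable_q_quant: bool,
                 kv_cache_dtype: torch.dtype, model_dtype: torch.dtype) -> torch.dtype:
    # Illustrative restatement of the builder logic above: quantize the query
    # to the (possibly fp8) KV-cache dtype only when the TRTLLM kernels are
    # usable and quantization was not disabled via the attention config.
    if can_use_trtllm and not disable_q_quant:
        return kv_cache_dtype
    return model_dtype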
@ -1035,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
        self.sinks = sinks

        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
+       vllm_config = get_current_vllm_config()
+       self.supports_quant_query_input = (
+           self.support_trtllm_attn
+           and not vllm_config.attention_config.disable_flashinfer_q_quantization
+       )
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None
        self.o_sf_scale: float | None = None
@ -1046,12 +1054,6 @@ class FlashInferImpl(AttentionImpl):
            and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
        )

-   def supports_quant_query_input(self) -> bool:
-       if flashinfer_disable_q_quantization():
-           return False
-
-       return self.support_trtllm_attn
-
    # FlashInfer requires attention sinks to be float32
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if self.sinks is not None and self.sinks.dtype != torch.float32:
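supports_quant_query_input also changes shape here: it was a method that re-checked an environment variable on every call and is now a plain attribute computed once in __init__ from the config, a pattern this commit applies to the FlashAttention and Triton backends as well. A minimal sketch of the pattern (not the real FlashInferImpl):

class QuantQueryCapableImpl:
    # Minimal sketch of the new pattern, not the real FlashInferImpl:
    # the capability is resolved once at construction time.
    def __init__(self, support_trtllm_attn: bool, disable_q_quant: bool):
        self.supports_quant_query_input = support_trtllm_attn and not disable_q_quant


impl = QuantQueryCapableImpl(support_trtllm_attn=True, disable_q_quant=False)
assert impl.supports_quant_query_input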
@ -438,19 +438,25 @@ A = TypeVar("A")
 def use_flashinfer_prefill() -> bool:
     # For blackwell default to flashinfer prefill if it's available since
     # it is faster than FA2.
+    from vllm.config import get_current_vllm_config
+
+    vllm_config = get_current_vllm_config()
     return (
-        not envs.VLLM_DISABLE_FLASHINFER_PREFILL
+        not vllm_config.attention_config.disable_flashinfer_prefill
         and flashinfer_available
-        and not envs.VLLM_USE_CUDNN_PREFILL
+        and not vllm_config.attention_config.use_cudnn_prefill
-        and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
         and current_platform.is_device_capability(100)
     )


 def use_cudnn_prefill() -> bool:
+    from vllm.config import get_current_vllm_config
+
+    vllm_config = get_current_vllm_config()
     return (
         flashinfer_available
-        and envs.VLLM_USE_CUDNN_PREFILL
+        and vllm_config.attention_config.use_cudnn_prefill
         and current_platform.is_device_capability(100)
         and has_nvidia_artifactory()
     )
@ -458,9 +464,12 @@ def use_cudnn_prefill() -> bool:

 def use_trtllm_ragged_deepseek_prefill() -> bool:
     """Check if TRT-LLM ragged DeepSeek prefill should be used."""
+    from vllm.config import get_current_vllm_config
+
+    vllm_config = get_current_vllm_config()
     return (
         flashinfer_available
-        and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
         and current_platform.is_device_capability(100)
     )

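The three prefill toggles formerly read from envs are now ordinary attention-config fields. A condensed, illustrative view of how they combine for the FlashInfer prefill path (cfg stands in for the attention config; field names come from the hunk above):

def flashinfer_prefill_enabled(cfg, flashinfer_available: bool, is_blackwell: bool) -> bool:
    # Illustrative helper: mirrors use_flashinfer_prefill() above without the
    # platform imports. `cfg` plays the role of vllm_config.attention_config.
    return (
        not cfg.disable_flashinfer_prefill
        and flashinfer_available
        and not cfg.use_cudnn_prefill
        and not cfg.use_trtllm_ragged_deepseek_prefill
        and is_blackwell
    )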
@ -6,7 +6,6 @@ from typing import ClassVar

 import torch

-from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionLayer,
     AttentionType,
@ -131,7 +130,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
        # When using cuda graph, we need to set the upper bound of the
        # number of splits so that large enough intermediate buffers are
        # pre-allocated during capture.
-       self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+       self.max_num_splits = (
+           vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
+       )

        if vllm_is_batch_invariant():
            self.max_num_splits = 1
@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
-           "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
+           "Set --attention-config.backend=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes."
        )

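The error message now points at the config-based override rather than the environment variable. Assuming AttentionConfig accepts backend as a constructor argument and that FLEX_ATTENTION is a member of AttentionBackendEnum (both assumptions here), the programmatic equivalent would be roughly:

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig

# Rough equivalent of --attention-config.backend=FLEX_ATTENTION on the CLI;
# the FLEX_ATTENTION member and the keyword argument are assumptions.
cfg = AttentionConfig(backend=AttentionBackendEnum.FLEX_ATTENTION)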
@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl):
    def fused_output_quant_supported(self, quant_key: QuantKey):
        return quant_key == kFp8StaticTensorSym

-   def supports_quant_query_input(self) -> bool:
-       return current_platform.is_cuda()
-
    def __init__(
        self,
        num_heads: int,
@ -262,6 +259,8 @@ class TritonAttentionImpl(AttentionImpl):
            f"num_heads: {num_heads}."
        )

+       self.supports_quant_query_input = current_platform.is_cuda()
+
    def forward(
        self,
        layer: torch.nn.Module,