diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 9b4486e56c73..db95dff5e0fc 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import Attention -from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( + AttentionConfig, CacheConfig, CompilationConfig, CompilationMode, @@ -335,6 +335,7 @@ def test_attention_quant_pattern( custom_ops=custom_ops_list, ), cache_config=CacheConfig(cache_dtype="fp8"), + attention_config=AttentionConfig(backend=backend), ) # Create test inputs @@ -352,7 +353,6 @@ def test_attention_quant_pattern( with ( set_current_vllm_config(vllm_config_unfused), set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), - global_force_attn_backend_context_manager(backend), ): model_unfused = model_class( num_qo_heads=num_qo_heads, @@ -378,7 +378,6 @@ def test_attention_quant_pattern( with ( set_current_vllm_config(vllm_config), set_forward_context(attn_metadata=None, vllm_config=vllm_config), - global_force_attn_backend_context_manager(backend), ): model_fused = model_class( num_qo_heads=num_qo_heads, diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 5045ae0eef33..ec9ff7315d09 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1151,13 +1151,29 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): } # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - ] + test_shape = backend_cls.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 + + if is_blocks_first: + expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel() + expected_base_addrs = [ + shared_tensor.data_ptr(), + unique_tensor.data_ptr(), + ] + expected_num_entries = 2 + else: + expected_tensor_size = ( + shared_tensor[0].element_size() * shared_tensor[0].numel() + ) + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + ] + expected_num_entries = 4 with ( patch( @@ -1192,7 +1208,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): # Verify get_reg_descs was called with caches_data assert mock_wrapper_instance.get_reg_descs.called caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] - assert len(caches_data) == 4 + assert len(caches_data) == expected_num_entries for i, cache_entry in enumerate(caches_data): base_addr, size, _tp_rank, _ = cache_entry @@ -1214,7 +1230,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): f"Expected 
{expected_blocks_count} blocks, got {len(blocks_data)}" ) - expected_block_len = expected_tensor_size // 2 + num_blocks = 2 + if is_blocks_first: + expected_block_len = expected_tensor_size // num_blocks // 2 + else: + expected_block_len = expected_tensor_size // num_blocks + for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry assert block_len == expected_block_len, ( diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 459abcfdd53c..7b8c4268a523 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -6,8 +6,10 @@ import pytest import torch from vllm.attention.backends.abstract import MultipleOf +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import Attention from vllm.config import ( + AttentionConfig, CacheConfig, ModelConfig, ParallelConfig, @@ -765,7 +767,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): current_platform.is_rocm(), reason="Attention backend FLASHINFER is not supported on ROCm.", ) -def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): +def test_hybrid_attention_mamba_tensor_shapes(): """ The GPU model runner creates different views into the KVCacheTensors for the attention and mamba layers @@ -806,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): cache_dtype="auto", ) parallel_config = ParallelConfig() + attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER) vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, scheduler_config=scheduler_config, parallel_config=parallel_config, + attention_config=attention_config, ) layer_0 = "model.layers.0.self_attn.attn" @@ -820,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): layer_4 = "model.layers.4.mixer" layer_5 = "model.layers.5.mixer" - with set_current_vllm_config(vllm_config), monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + with set_current_vllm_config(vllm_config): hf_config = vllm_config.model_config.hf_config fwd_context = {} for key in [layer_0, layer_1]: @@ -851,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ) # suppress var not used error assert fwd_context is not None - vllm_ctx = vllm_config.compilation_config.static_forward_context - - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + vllm_ctx = vllm_config.compilation_config.static_forward_context runner = GPUModelRunner(vllm_config, DEVICE) kv_cache_spec = runner.get_kv_cache_spec() @@ -865,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): )[0] runner.initialize_kv_cache(kv_cache_config) - # random partition of blocks - # blocks0 will be assigned to attention layers - # blocks1 will be assigned to mamba layers - num_blocks = kv_cache_config.num_blocks - ind = np.arange(num_blocks) - np.random.shuffle(ind) - blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] + # random partition of blocks + # blocks0 will be assigned to attention layers + # blocks1 will be assigned to mamba layers + num_blocks = kv_cache_config.num_blocks + ind = np.arange(num_blocks) + np.random.shuffle(ind) + blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] - attn_shape = vllm_ctx[layer_0].kv_cache[0].shape - conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape - ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape + attn_shape = 
vllm_ctx[layer_0].kv_cache[0].shape + conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape + ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape - # assert we are using FlashInfer - assert attn_shape[0] % num_blocks == 0 - block_split_ratio = attn_shape[0] // num_blocks + # assert we are using FlashInfer + assert attn_shape[0] % num_blocks == 0 + block_split_ratio = attn_shape[0] // num_blocks - # use small blocks for testing to avoid memory issues - test_block_size = min(2, len(blocks0), len(blocks1)) + # use small blocks for testing to avoid memory issues + test_block_size = min(2, len(blocks0), len(blocks1)) - # use non-overlapping blocks to avoid data contamination - # Split kernel blocks: first half for attention, second half for mamba - mid_point = num_blocks // 2 + # use non-overlapping blocks to avoid data contamination + # Split kernel blocks: first half for attention, second half for mamba + mid_point = num_blocks // 2 - # attention uses kernel blocks from first half (mapped to logical blocks) - kv_blocks_for_attention = np.array([0, 1])[:test_block_size] + # attention uses kernel blocks from first half (mapped to logical blocks) + kv_blocks_for_attention = np.array([0, 1])[:test_block_size] - # mamba uses kernel blocks from second half - kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] + # mamba uses kernel blocks from second half + kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] - # create small constant tensors for testing with corrected shapes - # attention: [block_size, ...] starting from dimension 2 - attn_constant_shape = attn_shape[2:] - conv_constant_shape = conv_shape[1:] - ssm_constant_shape = ssm_shape[1:] + # create small constant tensors for testing with corrected shapes + # attention: [block_size, ...] starting from dimension 2 + attn_constant_shape = attn_shape[2:] + conv_constant_shape = conv_shape[1:] + ssm_constant_shape = ssm_shape[1:] - attn_blocks_constant = torch.full( - (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 - ) - conv_blocks_constant = torch.full( - (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 - ) - ssm_blocks_constant = torch.full( - (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 - ) + attn_blocks_constant = torch.full( + (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 + ) + conv_blocks_constant = torch.full( + (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 + ) + ssm_blocks_constant = torch.full( + (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 + ) - # Fill attention blocks with constants using kv block indices - kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio + # Fill attention blocks with constants using kv block indices + kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio - for layer in [layer_0, layer_1]: - # attention: kv_cache[0][kernel_block_idx, kv_idx, ...] - for i, kernel_block in enumerate(kernel_blocks_for_attention): - vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] + for layer in [layer_0, layer_1]: + # attention: kv_cache[0][kernel_block_idx, kv_idx, ...] + for i, kernel_block in enumerate(kernel_blocks_for_attention): + vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] - # fill mamba blocks with constants using kernel block indices - for layer in [layer_2, layer_3, layer_4, layer_5]: - # mamba: kv_cache[0][component][kernel_block_idx, ...] 
- for i, kv_block in enumerate(kv_blocks_for_mamba): - vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] - vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] + # fill mamba blocks with constants using kernel block indices + for layer in [layer_2, layer_3, layer_4, layer_5]: + # mamba: kv_cache[0][component][kernel_block_idx, ...] + for i, kv_block in enumerate(kv_blocks_for_mamba): + vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] + vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] - # verify attention and mamba contents are correct - for layer in [layer_0, layer_1]: - for i, kernel_block in enumerate(kernel_blocks_for_attention): - actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] - expected = attn_blocks_constant[i] + # verify attention and mamba contents are correct + for layer in [layer_0, layer_1]: + for i, kernel_block in enumerate(kernel_blocks_for_attention): + actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] + expected = attn_blocks_constant[i] - # Check K and V separately - assert torch.equal(actual_kv[0], expected) - assert torch.equal(actual_kv[1], expected) + # Check K and V separately + assert torch.equal(actual_kv[0], expected) + assert torch.equal(actual_kv[1], expected) - for layer in [layer_2, layer_3, layer_4, layer_5]: - for i, kv_block in enumerate(kv_blocks_for_mamba): - actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] - actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] - expected_conv = conv_blocks_constant[i] - expected_ssm = ssm_blocks_constant[i] + for layer in [layer_2, layer_3, layer_4, layer_5]: + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] - assert torch.equal(actual_conv, expected_conv) - assert torch.equal(actual_ssm, expected_ssm) + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) - for layer in [layer_2, layer_3, layer_4, layer_5]: - for i, kv_block in enumerate(kv_blocks_for_mamba): - actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] - actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] - expected_conv = conv_blocks_constant[i] - expected_ssm = ssm_blocks_constant[i] - assert torch.equal(actual_conv, expected_conv) - assert torch.equal(actual_ssm, expected_ssm) + for layer in [layer_2, layer_3, layer_4, layer_5]: + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) def test_hybrid_block_table_initialization(): diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index c290670eeacb..84cca8e68607 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -289,6 +289,16 @@ class AttentionImpl(ABC, Generic[T]): # even if they can return lse (for efficiency reasons) need_to_return_lse_for_decode: bool = False + # Whether this attention implementation supports pre-quantized query input. 
+ # When True, the attention layer will quantize queries before passing them + # to this backend, allowing torch.compile to fuse the quantization with + # previous operations. This is typically supported when using FP8 KV cache + # with compatible attention kernels (e.g., TRT-LLM). + # Subclasses should set this in __init__. + # TODO add support to more backends: + # https://github.com/vllm-project/vllm/issues/25584 + supports_quant_query_input: bool = False + dcp_world_size: int dcp_rank: int @@ -368,22 +378,6 @@ class AttentionImpl(ABC, Generic[T]): """ return False - def supports_quant_query_input(self) -> bool: - """ - Check if this attention implementation supports pre-quantized query input. - - When True, the attention layer will quantize queries before passing them - to this backend, allowing torch.compile to fuse the quantization with - previous operations. This is typically supported when using FP8 KV cache - with compatible attention kernels (e.g., TRT-LLM). - TODO add support to more backends: - https://github.com/vllm-project/vllm/issues/25584 - - Returns: - bool: True if the implementation can accept pre-quantized queries. - """ - return False - def process_weights_after_loading(self, act_dtype: torch.dtype): pass diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index da5a62617129..8a522deedf3c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -303,7 +303,7 @@ class Attention(nn.Module, AttentionLayerBase): self.query_quant = None if ( self.kv_cache_dtype.startswith("fp8") - and self.impl.supports_quant_query_input() + and self.impl.supports_quant_query_input ): self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) @@ -338,7 +338,7 @@ class Attention(nn.Module, AttentionLayerBase): assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} # check if query quantization is supported - if self.impl.supports_quant_query_input(): + if self.impl.supports_quant_query_input: query, _ = self.query_quant(query, self._q_scale) if self.use_output: diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index a7190df3c4f1..aeb130dfe872 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -2,19 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import inspect -import os -from collections.abc import Generator -from contextlib import contextmanager from functools import cache from typing import cast, get_args import torch -import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.registry import ( MAMBA_TYPE_TO_BACKEND_MAP, - AttentionBackendEnum, MambaAttentionBackendEnum, ) from vllm.config.cache import CacheDType @@ -24,60 +19,6 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) -def get_env_variable_attn_backend() -> AttentionBackendEnum | None: - """ - Get the backend override specified by the vLLM attention - backend environment variable, if one is specified. - - Returns: - - * AttentionBackendEnum value if an override is specified - * None otherwise - """ - backend_name = os.environ.get("VLLM_ATTENTION_BACKEND") - if backend_name is None: - return None - if backend_name == "XFORMERS": - raise ValueError( - "Attention backend 'XFORMERS' has been removed (See PR #29262 for " - "details). Please select a supported attention backend." 
- ) - return AttentionBackendEnum[backend_name] - - -# Global state allows a particular choice of backend -# to be forced, overriding the logic which auto-selects -# a backend based on system & workload configuration -# (default behavior if this variable is None) -# -# THIS SELECTION TAKES PRECEDENCE OVER THE -# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: AttentionBackendEnum | None = None - - -def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: - """ - Force all attention operations to use a specified backend. - - Passing `None` for the argument re-enables automatic - backend selection., - - Arguments: - - * attn_backend: backend selection (None to revert to auto) - """ - global forced_attn_backend - forced_attn_backend = attn_backend - - -def get_global_forced_attn_backend() -> AttentionBackendEnum | None: - """ - Get the currently-forced choice of attention backend, - or None if auto-selection is currently enabled. - """ - return forced_attn_backend - - def get_attn_backend( head_size: int, dtype: torch.dtype, @@ -97,7 +38,13 @@ def get_attn_backend( f"Valid values are: {valid_cache_dtypes}" ) + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + backend_enum = vllm_config.attention_config.backend + return _cached_get_attn_backend( + backend=backend_enum, head_size=head_size, dtype=dtype, kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), @@ -111,6 +58,7 @@ def get_attn_backend( @cache def _cached_get_attn_backend( + backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: CacheDType | None, @@ -120,39 +68,6 @@ def _cached_get_attn_backend( use_sparse: bool = False, attn_type: str | None = None, ) -> type[AttentionBackend]: - # Check whether a particular choice of backend was - # previously forced. - # - # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND - # ENVIRONMENT VARIABLE. - selected_backend = None - backend_by_global_setting: AttentionBackendEnum | None = ( - get_global_forced_attn_backend() - ) - if backend_by_global_setting is not None: - selected_backend = backend_by_global_setting - else: - # Check the environment variable and override if specified - backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - if backend_by_env_var.endswith("_VLLM_V1"): - logger.warning( - "The suffix '_VLLM_V1' in the environment variable " - "VLLM_ATTENTION_BACKEND is no longer necessary as " - "V0 backends have been deprecated. " - "Please remove this suffix from your " - "environment variable setting.", - ) - backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - try: - selected_backend = AttentionBackendEnum[backend_by_env_var] - except KeyError as e: - raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. Valid " - f"backends are: {list(AttentionBackendEnum.__members__.keys())}" - ) from e - - # get device-specific attn_backend from vllm.platforms import current_platform sig = inspect.signature(current_platform.get_attn_backend_cls) @@ -163,7 +78,7 @@ def _cached_get_attn_backend( "remove it from your plugin code." 
) attention_cls = current_platform.get_attn_backend_cls( - selected_backend, + backend, head_size, dtype, kv_cache_dtype, @@ -176,7 +91,7 @@ def _cached_get_attn_backend( ) else: attention_cls = current_platform.get_attn_backend_cls( - selected_backend, + backend, head_size, dtype, kv_cache_dtype, @@ -232,37 +147,3 @@ def _cached_get_mamba_attn_backend( mamba_attn_backend = selected_backend.get_class() return mamba_attn_backend - - -@contextmanager -def global_force_attn_backend_context_manager( - attn_backend: AttentionBackendEnum, -) -> Generator[None, None, None]: - """ - Globally force a vLLM attention backend override within a - context manager, reverting the global attention backend - override to its prior state upon exiting the context - manager. - - Arguments: - - * attn_backend: attention backend to force - - Returns: - - * Generator - """ - - # Save the current state of the global backend override (if any) - original_value = get_global_forced_attn_backend() - - # Globally force the new backend override - global_force_attn_backend(attn_backend) - - # Yield control back to the enclosed code block - try: - yield - finally: - # Revert the original global backend override, if any - global_force_attn_backend(original_value) - _cached_get_attn_backend.cache_clear() diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 8a46587473e4..e38c88f4838d 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform @@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: 3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 ) - # 2. override if passed by environment - if envs.VLLM_FLASH_ATTN_VERSION is not None: - assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] - fa_version = envs.VLLM_FLASH_ATTN_VERSION + # 2. override if passed by environment or config + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config.attention_config.flash_attn_version is not None: + fa_version = vllm_config.attention_config.flash_attn_version # 3. fallback for unsupported combinations if device_capability.major == 10 and fa_version == 3: diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index dd76a722106e..0f84f3ca9d3e 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config.attention import AttentionConfig from vllm.config.cache import CacheConfig from vllm.config.compilation import ( CompilationConfig, @@ -46,6 +47,8 @@ from vllm.config.vllm import ( # __all__ should only contain classes and functions. # Types and globals should be imported from their respective modules. 
__all__ = [ + # From vllm.config.attention + "AttentionConfig", # From vllm.config.cache "CacheConfig", # From vllm.config.compilation diff --git a/vllm/config/attention.py b/vllm/config/attention.py new file mode 100644 index 000000000000..dd62d88826bd --- /dev/null +++ b/vllm/config/attention.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Literal + +from pydantic import field_validator +from pydantic.dataclasses import dataclass + +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.utils import config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@config +@dataclass +class AttentionConfig: + """Configuration for attention mechanisms in vLLM.""" + + backend: AttentionBackendEnum | None = None + """Attention backend to use. If None, will be selected automatically.""" + + flash_attn_version: Literal[2, 3] | None = None + """Force vllm to use a specific flash-attention version (2 or 3). + Only valid when using the flash-attention backend.""" + + use_prefill_decode_attention: bool = False + """Use separate prefill and decode kernels for attention instead of + the unified triton kernel.""" + + flash_attn_max_num_splits_for_cuda_graph: int = 32 + """Flash Attention max number splits for cuda graph decode.""" + + use_cudnn_prefill: bool = False + """Whether to use cudnn prefill.""" + + use_trtllm_ragged_deepseek_prefill: bool = False + """Whether to use TRTLLM ragged deepseek prefill.""" + + use_trtllm_attention: bool | None = None + """If set to True/False, use or don't use the TRTLLM attention backend + in flashinfer. If None, auto-detect the attention backend in flashinfer.""" + + disable_flashinfer_prefill: bool = False + """Whether to disable flashinfer prefill.""" + + disable_flashinfer_q_quantization: bool = False + """If set, when using fp8 kv, do not quantize Q to fp8.""" + + def compute_hash(self) -> str: + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + from vllm.config.utils import get_hash_factors, hash_factors + + ignored_factors: list[str] = [] + factors = get_hash_factors(self, ignored_factors) + return hash_factors(factors) + + @field_validator("backend", mode="before") + @classmethod + def validate_backend_before(cls, value: Any) -> Any: + """Enable parsing of the `backend` enum type from string.""" + if isinstance(value, str): + return AttentionBackendEnum[value.upper()] + return value + + def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None: + """Set field from env var if set, with deprecation warning.""" + from vllm import envs + + if envs.is_set(env_var_name): + value = getattr(envs, env_var_name) + if field_name == "backend": + value = self.validate_backend_before(value) + setattr(self, field_name, value) + logger.warning_once( + "Using %s environment variable is deprecated and will be removed in " + "v0.14.0 or v1.0.0, whichever is soonest. Please use " + "--attention-config.%s command line argument or " + "AttentionConfig(%s=...) 
config field instead.", + env_var_name, + field_name, + field_name, + ) + + def __post_init__(self) -> None: + self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND") + self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION") + self._set_from_env_if_set( + "use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION" + ) + self._set_from_env_if_set( + "flash_attn_max_num_splits_for_cuda_graph", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", + ) + self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL") + self._set_from_env_if_set( + "use_trtllm_ragged_deepseek_prefill", + "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", + ) + self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION") + self._set_from_env_if_set( + "disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL" + ) + self._set_from_env_if_set( + "disable_flashinfer_q_quantization", + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", + ) diff --git a/vllm/config/model.py b/vllm/config/model.py index ae5189ce68d9..5be7d5e7f2df 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -4,7 +4,6 @@ import warnings from collections.abc import Callable from dataclasses import InitVar, field -from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch @@ -467,18 +466,6 @@ class ModelConfig: self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if ( - (backend := envs.VLLM_ATTENTION_BACKEND) - and backend == "FLASHINFER" - and find_spec("flashinfer") is None - ): - raise ValueError( - "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " - "module was not found. See " - "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 - "for instructions on how to install it." 
- ) - from vllm.platforms import current_platform if self.override_attention_dtype is not None and not current_platform.is_rocm(): diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 823bd96db9ac..ce3d3b20865d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid from vllm.utils.hashing import safe_hash +from .attention import AttentionConfig from .cache import CacheConfig from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig @@ -192,6 +193,8 @@ class VllmConfig: """Device configuration.""" load_config: LoadConfig = Field(default_factory=LoadConfig) """Load configuration.""" + attention_config: AttentionConfig = Field(default_factory=AttentionConfig) + """Attention configuration.""" lora_config: LoRAConfig | None = None """LoRA configuration.""" speculative_config: SpeculativeConfig | None = None @@ -279,6 +282,10 @@ class VllmConfig: vllm_factors.append(self.load_config.compute_hash()) else: vllm_factors.append("None") + if self.attention_config: + vllm_factors.append(self.attention_config.compute_hash()) + else: + vllm_factors.append("None") if self.lora_config: vllm_factors.append(self.lora_config.compute_hash()) else: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 883ae370f9e7..aad0719548d1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ from typing_extensions import TypeIs import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( + AttentionConfig, CacheConfig, CompilationConfig, ConfigType, @@ -527,6 +528,7 @@ class EngineArgs: pooler_config: PoolerConfig | None = ModelConfig.pooler_config compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") + attention_config: AttentionConfig = get_field(VllmConfig, "attention_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls @@ -542,6 +544,7 @@ class EngineArgs: ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype + attention_backend: AttentionBackendEnum | None = AttentionConfig.backend calculate_kv_scales: bool = CacheConfig.calculate_kv_scales mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype @@ -580,6 +583,8 @@ class EngineArgs: # CompilationConfig object if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig(**self.compilation_config) + if isinstance(self.attention_config, dict): + self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins @@ -717,6 +722,16 @@ class EngineArgs: "--pt-load-map-location", **load_kwargs["pt_load_map_location"] ) + # Attention arguments + attention_kwargs = get_kwargs(AttentionConfig) + attention_group = parser.add_argument_group( + title="AttentionConfig", + description=AttentionConfig.__doc__, + ) + attention_group.add_argument( + "--attention-backend", **attention_kwargs["backend"] + ) + # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) structured_outputs_group = parser.add_argument_group( @@ -1140,6 +1155,9 @@ class EngineArgs: vllm_group.add_argument( "--compilation-config", "-cc", **vllm_kwargs["compilation_config"] ) + vllm_group.add_argument( + 
"--attention-config", "-ac", **vllm_kwargs["attention_config"] + ) vllm_group.add_argument( "--additional-config", **vllm_kwargs["additional_config"] ) @@ -1693,6 +1711,16 @@ class EngineArgs: if model_config.quantization == "bitsandbytes": self.quantization = self.load_format = "bitsandbytes" + # Attention config overrides + attention_config = copy.deepcopy(self.attention_config) + if self.attention_backend is not None: + if attention_config.backend is not None: + raise ValueError( + "attention_backend and attention_config.backend " + "are mutually exclusive" + ) + attention_config.backend = self.attention_backend + load_config = self.create_load_config() # Pass reasoning_parser into StructuredOutputsConfig @@ -1750,9 +1778,10 @@ class EngineArgs: parallel_config=parallel_config, scheduler_config=scheduler_config, device_config=device_config, + load_config=load_config, + attention_config=attention_config, lora_config=lora_config, speculative_config=speculative_config, - load_config=load_config, structured_outputs_config=self.structured_outputs_config, observability_config=observability_config, compilation_config=compilation_config, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index fbeb28a1c0b3..55dd6e50ad24 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -4,7 +4,7 @@ from copy import deepcopy from math import lcm from typing import TYPE_CHECKING -import vllm.envs as envs +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform @@ -331,6 +331,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) + attention_config = vllm_config.attention_config cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -347,7 +348,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # * CUTLASS_MLA backend: kernel_block_size 128 alignment # * Other MLA backends: kernel_block_size 64 alignment if model_config.use_mla: - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + use_cutlass_mla = ( + attention_config.backend == AttentionBackendEnum.CUTLASS_MLA + ) kernel_block_alignment_size = 128 if use_cutlass_mla else 64 attn_page_size_1_token = MLAAttentionSpec( block_size=1, @@ -361,8 +364,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): current_platform.is_device_capability(100) and model_config.get_head_size() == 256 and ( - envs.VLLM_ATTENTION_BACKEND is None - or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER" + attention_config.backend is None + or attention_config.backend == AttentionBackendEnum.FLASHINFER ) ): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that` diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index e5d70eb7bc2f..7602eca9c325 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -11,7 +11,7 @@ import torch from transformers import PretrainedConfig from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -91,10 +91,7 @@ def 
get_vit_attn_backend( if attn_backend_override is not None: return attn_backend_override - # Lazy import to avoid circular dependency - from vllm.attention.selector import get_env_variable_attn_backend - - selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend() + selected_backend = get_current_vllm_config().attention_config.backend if selected_backend is not None: return selected_backend diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 1467ca71efec..7e6ce6aeef53 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,7 +14,6 @@ from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa -import vllm.envs as envs from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger @@ -149,6 +148,8 @@ class CudaPlatformBase(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + from vllm.attention.backends.registry import AttentionBackendEnum + parallel_config = vllm_config.parallel_config model_config = vllm_config.model_config @@ -171,7 +172,7 @@ class CudaPlatformBase(Platform): and cache_config.block_size is not None ): use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") - # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, + # If `--attention-config.backend` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the # required block_size. @@ -179,23 +180,25 @@ class CudaPlatformBase(Platform): use_cutlass_mla = False use_flashinfer_mla = False - if envs.VLLM_ATTENTION_BACKEND is None: + if vllm_config.attention_config.backend is None: # Default case if cls.is_device_capability(100): # Blackwell => Force CutlassMLA. use_cutlass_mla = True - # TODO: This does not work, because the - # global_force_attn_backend_context_manager is not set. 
- # See vllm/attention/selector.py:_cached_get_attn_backend - envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" + # Set the backend in AttentionConfig so it's used during + # backend selection + vllm_config.attention_config.backend = ( + AttentionBackendEnum.CUTLASS_MLA + ) else: # Not Blackwell use_flashmla = True else: # Forced case - use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" - use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" + backend = vllm_config.attention_config.backend + use_flashmla = backend == AttentionBackendEnum.FLASHMLA + use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA + use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA from vllm.attention.ops.flashmla import is_flashmla_dense_supported diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 9f9976d52b4a..7aaf690cbaa1 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool: return current_platform.is_device_capability(100) and has_nvidia_artifactory() -@functools.cache -def _force_use_trtllm_attention(env_value: bool | None) -> bool | None: - """Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" - if env_value is not None: - logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) - return env_value - - def force_use_trtllm_attention() -> bool | None: """ - Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set, + Return `None` if --attention-config.use_trtllm_attention is not set, return `True` if TRTLLM attention is forced to be used, return `False` if TRTLLM attention is forced to be not used. """ - return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + return vllm_config.attention_config.use_trtllm_attention def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: @@ -307,7 +302,7 @@ def use_trtllm_attention( """Return `True` if TRTLLM attention is used.""" force_use_trtllm = force_use_trtllm_attention() - # Environment variable is set to 0 - respect it + # CLI argument is set to 0 - respect it if force_use_trtllm is not None and not force_use_trtllm: return False @@ -324,7 +319,7 @@ def use_trtllm_attention( if force_use_trtllm: logger.warning_once( "TRTLLM attention is not supported on this platform, " - "but VLLM_USE_TRTLLM_ATTENTION is set to 1" + "but --attention-config.use_trtllm_attention is set to 1" ) return False @@ -333,7 +328,8 @@ def use_trtllm_attention( if force_use_trtllm: logger.warning_once( "TRTLLM attention is not supported for this combination of " - "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1" + "query and key heads, but --attention-config.use_trtllm_attention is " + "set to 1" ) return False @@ -354,7 +350,7 @@ def use_trtllm_attention( return True if force_use_trtllm is None: - # Environment variable not set - use auto-detection + # CLI argument not set - use auto-detection if is_prefill: # Prefill auto-detection use_trtllm = kv_cache_dtype == "auto" @@ -367,8 +363,10 @@ def use_trtllm_attention( logger.warning_once("Using TRTLLM decode attention (auto-detected).") return use_trtllm - # Environment variable is set to 1 - respect it - logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") + # CLI argument is set to 1 - respect it + logger.info_once( + "Using TRTLLM attention 
(--attention-config.use_trtllm_attention is set to 1)" + ) return True @@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm( return output -@functools.cache -def flashinfer_disable_q_quantization() -> bool: - """Cache result which only depends on the environment""" - return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION - - __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -526,7 +518,6 @@ __all__ = [ "supports_trtllm_attention", "can_use_trtllm_attention", "use_trtllm_attention", - "flashinfer_disable_q_quantization", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fb080b0b33bc..f5ad98cf2125 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -8,7 +8,6 @@ from typing import ClassVar import numpy as np import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config + self.attention_config = vllm_config.attention_config self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config @@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + self.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl): "heads in the layer" ) - def supports_quant_query_input(self) -> bool: - return True + self.supports_quant_query_input = True def forward( self, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3d9640a2d402..8e9d764e4a12 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import ( ) from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger @@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.flashinfer import ( can_use_trtllm_attention, - flashinfer_disable_q_quantization, use_trtllm_attention, ) from vllm.utils.math_utils import cdiv @@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend): supports_trtllm_attention, ) - # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0) + # Respect explicit disable flag (e.g., + # --attention-config.use_trtllm_attention=0) if force_use_trtllm_attention() is False: return False @@ -500,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.kv_cache_dtype = self.kv_cache_spec.dtype # Use model dtype as q dtype when TRTLLM attn is not supported, or - # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to - # use fp8 q if kv cache is fp8, and will fall back to model dtype + # --attention-config.disable_flashinfer_q_quantization is set to 1. 
Otherwise, + # try to use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) - if can_use_trtllm and not flashinfer_disable_q_quantization(): + if ( + can_use_trtllm + and not vllm_config.attention_config.disable_flashinfer_q_quantization + ): self.q_data_type = self.kv_cache_dtype else: self.q_data_type = self.model_config.dtype @@ -1035,6 +1038,11 @@ class FlashInferImpl(AttentionImpl): self.sinks = sinks self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) + vllm_config = get_current_vllm_config() + self.supports_quant_query_input = ( + self.support_trtllm_attn + and not vllm_config.attention_config.disable_flashinfer_q_quantization + ) self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None self.o_sf_scale: float | None = None @@ -1046,12 +1054,6 @@ class FlashInferImpl(AttentionImpl): and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) ) - def supports_quant_query_input(self) -> bool: - if flashinfer_disable_q_quantization(): - return False - - return self.support_trtllm_attn - # FlashInfer requires attention sinks to be float32 def process_weights_after_loading(self, act_dtype: torch.dtype): if self.sinks is not None and self.sinks.dtype != torch.float32: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 180625b6ce89..309ddee4fc2f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -438,19 +438,25 @@ A = TypeVar("A") def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. 
+ from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( - not envs.VLLM_DISABLE_FLASHINFER_PREFILL + not vllm_config.attention_config.disable_flashinfer_prefill and flashinfer_available - and not envs.VLLM_USE_CUDNN_PREFILL - and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL + and not vllm_config.attention_config.use_cudnn_prefill + and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill and current_platform.is_device_capability(100) ) def use_cudnn_prefill() -> bool: + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( flashinfer_available - and envs.VLLM_USE_CUDNN_PREFILL + and vllm_config.attention_config.use_cudnn_prefill and current_platform.is_device_capability(100) and has_nvidia_artifactory() ) @@ -458,9 +464,12 @@ def use_cudnn_prefill() -> bool: def use_trtllm_ragged_deepseek_prefill() -> bool: """Check if TRT-LLM ragged DeepSeek prefill should be used.""" + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( flashinfer_available - and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL + and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill and current_platform.is_device_capability(100) ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index d369814c10b6..eccf4ec79109 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -6,7 +6,6 @@ from typing import ClassVar import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, @@ -131,7 +130,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) if vllm_is_batch_invariant(): self.max_num_splits = 1 diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 868143cc192e..e2410a70b1a6 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend): raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "Set --attention-config.backend=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index d051a89f03bb..3b17c4bcd89c 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -210,9 +210,6 @@ class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym - def supports_quant_query_input(self) -> bool: - return current_platform.is_cuda() - def __init__( self, num_heads: int, @@ -262,6 +259,8 @@ class TritonAttentionImpl(AttentionImpl): f"num_heads: {num_heads}." ) + self.supports_quant_query_input = current_platform.is_cuda() + def forward( self, layer: torch.nn.Module,
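Usage sketch for the AttentionConfig introduced above. The backend selection, flash-attention version, and the TRTLLM/cuDNN/FlashInfer knobs that were previously read from environment variables now live on vllm_config.attention_config; the old variables still feed the config through __post_init__ but emit a deprecation warning. This is a minimal sketch, assuming only the constructors and CLI flags visible in this diff (the JSON spelling of --attention-config is an assumption, modelled on the existing --compilation-config handling):

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import AttentionConfig

    # The backend can be passed as the enum or as a case-insensitive string;
    # the "before" field validator maps strings onto AttentionBackendEnum.
    cfg = AttentionConfig(backend="flashinfer")
    assert cfg.backend is AttentionBackendEnum.FLASHINFER  # with VLLM_ATTENTION_BACKEND unset

    # Former env knobs are now plain fields on the same object:
    cfg = AttentionConfig(
        backend=AttentionBackendEnum.FLASHINFER,
        use_trtllm_attention=True,                # was VLLM_USE_TRTLLM_ATTENTION
        disable_flashinfer_q_quantization=True,   # was VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
    )

    # CLI equivalents added in vllm/engine/arg_utils.py (illustrative invocation):
    #   --attention-backend FLASHINFER
    #   --attention-config '{"backend": "FLASHINFER"}'
    # Passing --attention-backend together with a backend inside
    # --attention-config raises a ValueError, as the two are mutually exclusive.

Note that _set_from_env_if_set() overrides explicitly passed values whenever the corresponding environment variable is set, so code that constructs AttentionConfig directly should run with the deprecated variables unset.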