diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 4f839348e5222..08bfcc974cc95 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -95,7 +96,17 @@ def kernel_paged_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([num_queries_per_kv_padded],
+                    float("-inf"),
+                    dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_head_idx,
+            mask=head_mask,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
     acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                    dtype=tl.float32)
@@ -223,6 +234,8 @@ def chunked_prefill_paged_decode(
         alibi_slopes=None,
         sliding_window=None,
         sm_scale=None,
+        # Optional tensor for sinks
+        sinks=None,
 ):
 
     if sm_scale is None:
@@ -253,6 +266,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            sinks=sinks,
         )
 
     block_size = value_cache.shape[3]
@@ -281,11 +295,17 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)
 
-    use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
-                                                 block_size,
-                                                 num_queries_per_kv,
-                                                 max_seq_len, sliding_window,
-                                                 kv_cache_dtype, alibi_slopes)
+    use_custom = use_rocm_custom_paged_attention(
+        query.dtype,
+        head_size,
+        block_size,
+        num_queries_per_kv,
+        max_seq_len,
+        sliding_window,
+        kv_cache_dtype,
+        alibi_slopes,
+        sinks,
+    )
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +354,7 @@ def chunked_prefill_paged_decode(
             query_ptr=query,
             key_cache_ptr=key_cache,
             value_cache_ptr=value_cache,
+            sink_ptr=sinks,
             block_tables_ptr=block_table,
             seq_lens_ptr=seq_lens,
             alibi_slopes_ptr=alibi_slopes,
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 13bef96722d2b..64c90337970f7 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -38,6 +38,7 @@ def _fwd_kernel(Q,
                 V,
                 K_cache,
                 V_cache,
+                sink_ptr,
                 B_Loc,
                 sm_scale,
                 k_scale,
@@ -126,7 +127,15 @@ def _fwd_kernel(Q,
                     other=0.0)  # [M,D]
 
     # initialize pointer to m and l
-    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        m_i = tl.load(
+            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
+            mask=(offs_m < cur_batch_query_len),
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]
 
@@ -732,7 +741,8 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None,
                           sm_scale=None,
-                          skip_decode=False):
+                          skip_decode=False,
+                          sinks=None):
 
     q_dtype_is_f32 = q.dtype is torch.float32
 
@@ -781,6 +791,7 @@ def context_attention_fwd(q,
         sliding_window = 0
 
     if alibi_slopes is not None:
+        assert sinks is None, "Sinks arg is not supported with alibi"
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
         # if q.dtype is torch.float32:
@@ -843,7 +854,7 @@ def context_attention_fwd(q,
     max_seq_len = 0 if max_seq_len is None else max_seq_len
     extra_kargs = {}
     if current_platform.is_rocm():
-        extra_kargs = {"kpack": 2, "waves_per_eu": 2}
+        extra_kargs = {"kpack": 1, "waves_per_eu": 2}
 
     grid = lambda META: (batch, head,
                          triton.cdiv(max_input_len, META["BLOCK_M"]))
@@ -853,6 +864,7 @@ def context_attention_fwd(q,
         v,
         k_cache,
         v_cache,
+        sinks,
         b_loc,
         sm_scale,
         k_scale,
diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 0fdba569f93f2..ba4299a2772df 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -52,6 +52,7 @@ def kernel_unified_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
         value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -131,7 +132,15 @@ def kernel_unified_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
@@ -292,6 +301,7 @@ def kernel_unified_attention_3d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -383,7 +393,15 @@ def kernel_unified_attention_3d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None or segm_idx != 0:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
@@ -627,6 +645,8 @@ def unified_attention(
     v_descale,
     alibi_slopes=None,
     qq_bias=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -635,6 +655,10 @@ def unified_attention(
     assert q.element_size() >= 2 or block_size >= 32, \
         "Block size must be at least 32 for fp8"
 
+    if sinks is not None:
+        assert sinks.shape[0] == q.shape[1], \
+            "Sinks must have num_query_heads elements"
+
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None
 
@@ -669,6 +693,7 @@ def unified_attention(
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
+            sink_ptr=sinks,
             block_tables_ptr=block_table,
             seq_lens_ptr=seqused_k,
             alibi_slopes_ptr=alibi_slopes,
@@ -741,6 +766,7 @@ def unified_attention(
             query_ptr=q,
             key_cache_ptr=k,
             value_cache_ptr=v,
+            sink_ptr=sinks,
             block_tables_ptr=block_table,
             seq_lens_ptr=seqused_k,
             alibi_slopes_ptr=alibi_slopes,
diff --git a/vllm/envs.py b/vllm/envs.py
index e28e9658e5b53..f8a7197dd1bba 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
+    VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -151,6 +152,8 @@ if TYPE_CHECKING:
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
+    VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
+    VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
 
 
 def get_default_cache_root():
@@ -326,6 +329,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
      ("true", "1")),
 
+    # Use AITER triton unified attention for V1 attention
+    "VLLM_USE_AITER_UNIFIED_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -1022,9 +1031,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
 
-    # If set to 1, use the TRTLLM Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    # If set to 1, use the TRTLLM Context Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
+
+    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),
 
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f086bab2556eb..95ba56b359379 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -373,6 +373,7 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +411,14 @@ class FlashAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")
 
+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.vllm_flash_attn_version == 3, (
+                "Sinks are only supported in FlashAttention 3")
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -534,6 +543,7 @@ class FlashAttentionImpl(AttentionImpl):
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
                 num_splits=attn_metadata.max_num_splits,
+                s_aux=self.sinks,
             )
 
         return output
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 942cb95eefa2f..c33afbfebcde2 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
 from dataclasses import dataclass
+from functools import cache
 from typing import ClassVar, Optional
 
 import torch
@@ -13,7 +14,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.ops.chunked_prefill_paged_decode import (
     chunked_prefill_paged_decode)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -193,6 +193,15 @@ class TritonAttentionBackend(AttentionBackend):
         return TritonAttentionMetadataBuilder
 
 
+@cache
+def use_aiter_unified_attention() -> bool:
+    """Check if aiter unified attention should be used."""
+    # VLLM_ROCM_USE_AITER_MHA needs to be set to 0 as well, since it
+    # defaults to 1
+    return envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
+
+
 class TritonAttentionImpl(AttentionImpl):
 
     def __init__(
@@ -207,6 +216,7 @@ class TritonAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -240,6 +250,29 @@ class TritonAttentionImpl(AttentionImpl):
         self.force_prefill_decode_attn = \
             envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
+        if not self.force_prefill_decode_attn:
+            # If not using prefill decode attention, we use the Triton
+            # unified attention implementation.
+            if use_aiter_unified_attention():
+                logger.info_once(
+                    "Using aiter unified attention for TritonAttentionImpl")
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+            else:
+                logger.info_once(
+                    "Using vllm unified attention for TritonAttentionImpl")
+                from vllm.attention.ops.triton_unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                f"heads in the layer. Sinks shape: {sinks.shape}, "
+                f"num_heads: {num_heads}.")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -342,28 +375,31 @@ class TritonAttentionImpl(AttentionImpl):
 
         if use_prefill_decode_attn:
             # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
-                                         key=key[:num_actual_tokens],
-                                         value=value[:num_actual_tokens],
-                                         output=output[:num_actual_tokens],
-                                         kv_cache_dtype=self.kv_cache_dtype,
-                                         key_cache=key_cache,
-                                         value_cache=value_cache,
-                                         block_table=block_table,
-                                         query_start_loc=cu_seqlens_q,
-                                         seq_lens=seqused_k,
-                                         max_seq_len=max_seqlen_k,
-                                         max_query_len=max_seqlen_q,
-                                         k_scale=layer._k_scale,
-                                         v_scale=layer._v_scale,
-                                         alibi_slopes=self.alibi_slopes,
-                                         sliding_window=self.sliding_window[0],
-                                         sm_scale=self.scale)
+            chunked_prefill_paged_decode(
+                query=query[:num_actual_tokens],
+                key=key[:num_actual_tokens],
+                value=value[:num_actual_tokens],
+                output=output[:num_actual_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_table=block_table,
+                query_start_loc=cu_seqlens_q,
+                seq_lens=seqused_k,
+                max_seq_len=max_seqlen_k,
+                max_query_len=max_seqlen_q,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
+                alibi_slopes=self.alibi_slopes,
+                sliding_window=self.sliding_window[0],
+                sm_scale=self.scale,
+                sinks=self.sinks,
+            )
         else:
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 
-            unified_attention(
+            self.unified_attention(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
@@ -381,6 +417,7 @@ class TritonAttentionImpl(AttentionImpl):
                 q_descale=None,  # Not supported
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                sinks=self.sinks,
             )
 
         return output
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 7aeea40b25a67..f521d94331b5e 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -254,7 +254,11 @@ def get_kv_cache_layout():
         # Override with format specified by the user.
         cache_layout = envs.VLLM_KV_CACHE_LAYOUT
         if cache_layout is None:
-            cache_layout = get_kv_connector_cache_layout()
+            if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+                    or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+                cache_layout = "HND"
+            else:
+                cache_layout = get_kv_connector_cache_layout()
         else:
             logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                 "detected. Setting KV cache layout to %s.", cache_layout)
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
 class PerLayerParameters:
     """
     Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
+    the same values for the following hyperparameters. This should not be used
+    for the trtllm-gen backend, since it supports different per-layer values
+    for these hyperparameters.
     """
 
     window_left: int
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
 def infer_global_hyperparameters(
         per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
     """
-    Currently, FlashInfer backend only support models in which all layers share
+    Currently, FlashInfer backends other than trtllm-gen
+    only support models in which all layers share
     the same values for the following hyperparameters:
     - `window_left`
     - `logits_soft_cap`
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(
 
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
-    for params in param_sets:
-        if params.window_left != global_params.window_left:
-            raise ValueError(
-                "Window left is not the same for all layers. One potential fix "
-                "is to set disable_sliding_window=True")
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
+
+    # trtllm attention doesn't need global hyperparameters, so skip the check
+    if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+            and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+        for params in param_sets:
+            if params.window_left != global_params.window_left:
+                raise ValueError(
+                    "Window left is not the same for all layers. " \
+                    "One potential fix is to set disable_sliding_window=True")
+            assert params == global_params, (
+                "FlashInfer backend currently only supports models in which "
+                "all layers share the same values "
+                "for the following hyperparameters: "
+                "`window_left`, `logits_soft_cap`, `sm_scale`.")
 
     return global_params