diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index ba4299a2772df..56ebed0f52448 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -75,6 +75,7 @@ def kernel_unified_attention_2d(
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -132,7 +133,7 @@ def kernel_unified_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    if sink_ptr is None:
+    if not USE_SINKS:
         M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
         M = tl.load(
@@ -322,6 +323,7 @@ def kernel_unified_attention_3d(
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -393,14 +395,17 @@ def kernel_unified_attention_3d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    if sink_ptr is None or segm_idx != 0:
-        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if USE_SINKS:
+        if segm_idx == 0:
+            M = tl.load(
+                sink_ptr + query_offset_1,
+                mask=query_mask_1,
+                other=float("-inf"),
+            ).to(dtype=tl.float32)
+        else:
+            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
-        M = tl.load(
-            sink_ptr + query_offset_1,
-            mask=query_mask_1,
-            other=float("-inf"),
-        ).to(dtype=tl.float32)
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
 
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -716,6 +721,7 @@ def unified_attention(
         USE_ALIBI_SLOPES=use_alibi_slopes,
         USE_QQ_BIAS=use_qq_bias,
         USE_SOFTCAP=(softcap > 0),
+        USE_SINKS=(sinks is not None),
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
@@ -787,6 +793,7 @@ def unified_attention(
         USE_ALIBI_SLOPES=use_alibi_slopes,
         USE_QQ_BIAS=use_qq_bias,
         USE_SOFTCAP=(softcap > 0),
+        USE_SINKS=(sinks is not None),
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
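
The change replaces a runtime `sink_ptr is None` check with a `tl.constexpr` flag, so Triton compiles a separate specialization per value of `USE_SINKS` and the untaken branch is eliminated at compile time, mirroring the existing `USE_ALIBI_SLOPES` / `USE_SOFTCAP` flags. Below is a minimal standalone sketch of that same constexpr-specialization pattern; the kernel and argument names (`add_optional_bias_kernel`, `USE_BIAS`, etc.) are illustrative only and are not part of vLLM's API.

```python
# Minimal sketch (not vLLM code) of compile-time specialization on an optional
# pointer, the pattern the USE_SINKS flag applies to the unified attention kernels.
from typing import Optional

import torch
import triton
import triton.language as tl


@triton.jit
def add_optional_bias_kernel(
    x_ptr,
    bias_ptr,  # may be passed as None when USE_BIAS is False
    out_ptr,
    n_elements,
    BLOCK: tl.constexpr,
    USE_BIAS: tl.constexpr,  # compile-time bool, analogous to USE_SINKS
):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    if USE_BIAS:
        # This branch only exists in the USE_BIAS=True specialization, so the
        # compiled USE_BIAS=False kernel never dereferences bias_ptr.
        x += tl.load(bias_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x, mask=mask)


def add_optional_bias(x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    # Mirrors USE_SINKS=(sinks is not None) at the unified_attention call sites:
    # the constexpr flag is derived from whether the optional tensor was given.
    # Recent Triton versions accept None for a pointer argument that the
    # specialized kernel never reads, as the vLLM kernel also relies on.
    add_optional_bias_kernel[grid](
        x, bias, out, x.numel(), BLOCK=1024, USE_BIAS=(bias is not None)
    )
    return out
```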