diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 08bfcc974cc95..dc10d7eca9c2a 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -60,6 +60,7 @@ def kernel_paged_attention_2d(
         stride_v_cache_3: tl.int64,  # int
         filter_by_query_len: tl.constexpr,  # bool
         query_start_len_ptr,  # [num_seqs+1]
+        USE_SINKS: tl.constexpr,  # bool
 ):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -96,7 +97,7 @@
 
     block_table_offset = seq_idx * block_table_stride
 
-    if sink_ptr is None:
+    if not USE_SINKS:
         M = tl.full([num_queries_per_kv_padded],
                     float("-inf"),
                     dtype=tl.float32)
@@ -386,4 +387,5 @@ def chunked_prefill_paged_decode(
         stride_v_cache_3=value_cache.stride(3),
         filter_by_query_len=True,
         query_start_len_ptr=query_start_loc,
+        USE_SINKS=sinks is not None,
     )
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 64c90337970f7..e1d41930f6231 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -81,6 +81,7 @@ def _fwd_kernel(Q,
                 num_unroll_cache: tl.constexpr,
                 num_unroll_request: tl.constexpr,
                 SKIP_DECODE: tl.constexpr,
+                USE_SINKS: tl.constexpr,
                 MAX_Q_LEN: tl.constexpr = 0,
                 MAX_CTX_LEN: tl.constexpr = 0):
 
@@ -127,7 +128,7 @@ def _fwd_kernel(Q,
                 other=0.0)  # [M,D]
 
     # initialize pointer to m and l
-    if sink_ptr is None:
+    if not USE_SINKS:
         m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
         m_i = tl.load(
@@ -910,5 +911,6 @@ def context_attention_fwd(q,
         num_unroll_request=1,
         num_warps=4,
         num_stages=1,
+        USE_SINKS=sinks is not None,
         **extra_kargs)
     return
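
For context, a minimal standalone sketch (an assumed example, not vLLM code) of the tl.constexpr flag pattern this patch applies: USE_SINKS is passed as a Python bool at launch time, so the branch on it is resolved when the kernel is compiled and the specialization without sinks never dereferences sink_ptr, which may then safely be None. All names below (_seed_running_max_kernel, seed_running_max, num_heads) are illustrative, not part of the patch.

# Sketch only; assumes Triton and a CUDA device are available.
import torch
import triton
import triton.language as tl


@triton.jit
def _seed_running_max_kernel(
        out_ptr,  # [BLOCK] float32 output
        sink_ptr,  # optional sink logits; may be passed as None
        num_heads,  # number of valid entries in sink_ptr
        USE_SINKS: tl.constexpr,  # bool, derived from `sinks is not None`
        BLOCK: tl.constexpr,
):
    offs = tl.arange(0, BLOCK)
    if not USE_SINKS:
        # No sinks: start the softmax running max at -inf.
        m_i = tl.full([BLOCK], float("-inf"), dtype=tl.float32)
    else:
        # Sinks provided: seed the running max from the sink logits.
        # This branch is compiled out when USE_SINKS is False, so a
        # None sink_ptr is never touched.
        m_i = tl.load(sink_ptr + offs, mask=offs < num_heads,
                      other=float("-inf")).to(tl.float32)
    tl.store(out_ptr + offs, m_i)


def seed_running_max(sinks, num_heads, block=128):
    out = torch.empty(block, dtype=torch.float32, device="cuda")
    _seed_running_max_kernel[(1, )](
        out,
        sinks,  # may be None
        num_heads,
        USE_SINKS=sinks is not None,  # mirrors the call sites in this patch
        BLOCK=block,
    )
    return out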