diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 16d67e3abe848..48db3ebfd7412 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # Authors:
-# - Burkhard Ringlein
-# - Jan van Lunteren
-# - Thomas Parnell
+# - Burkhard Ringlein
+# - Jan van Lunteren
+# - Chih-Chieh Yang
+# - Thomas Parnell
 
 import torch
 import triton
 import triton.language as tl
@@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
         v_scale,  # float32
         num_query_heads: tl.constexpr,  # int
         num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
         block_table_stride: tl.constexpr,  # int
         query_stride_0: tl.constexpr,  # int
         query_stride_1: tl.constexpr,  # int, should be equal to head_size
@@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
         query_start_len_ptr,  # [num_seqs+1]
 ):
     seq_idx = tl.program_id(0)
-    query_head_idx = tl.program_id(1)
-    kv_head_idx = query_head_idx // num_queries_per_kv
+    kv_head_idx = tl.program_id(1)
 
     if filter_by_query_len:
         cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
@@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
     else:
         cur_batch_in_all_start_index = seq_idx
 
+    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
+        0, num_queries_per_kv_padded)
+
     query_offset = (cur_batch_in_all_start_index * query_stride_0 +
-                    query_head_idx * query_stride_1)
+                    query_head_idx[:, None] * query_stride_1)
+
+    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
+    head_mask = head_mask & (query_head_idx < num_query_heads)
 
     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                         0).to(tl.int1)
 
-    # Q : (HEAD_SIZE,)
+    # Q : (num_queries_per_kv, HEAD_SIZE,)
     Q = tl.load(
-        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
-        mask=dim_mask,
+        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        mask=dim_mask[None, :] & head_mask[:, None],
         other=0.0,
     )
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([1], float("-inf"), dtype=tl.float32)
-    L = tl.full([1], 1.0, dtype=tl.float32)
-    acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
+    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
+    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
+                   dtype=tl.float32)
 
     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)
 
     # alibi slope for this head
     if USE_ALIBI_SLOPES:
-        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
+        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
+                              mask=head_mask,
+                              other=0.0)
 
     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
 
@@ -107,8 +117,8 @@ def kernel_paged_attention_2d(
 
         v_offset = (physical_block_idx * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_1 +
-                    offs_d[:, None] * stride_v_cache_2 +
-                    offs_n[None, :] * stride_v_cache_3)
+                    offs_d[None, :] * stride_v_cache_2 +
+                    offs_n[:, None] * stride_v_cache_3)
 
         k_offset = (physical_block_idx * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_1 +
@@ -126,9 +136,9 @@ def kernel_paged_attention_2d(
         else:
             K = K_load
 
-        # V : (HEAD_SIZE, BLOCK_SIZE)
+        # V : (BLOCK_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[None, :],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -136,51 +146,59 @@ def kernel_paged_attention_2d(
         else:
             V = V_load
 
-        tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
-        mask_new = tmp < boundary
-        # S : (BLOCK_SIZE,)
-        S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
-        S += scale * tl.sum(K * Q[:, None], axis=0)
+        seq_mask = seq_offset[None, :] < boundary
+
+        # S : (num_queries_per_kv, BLOCK_SIZE,)
+        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
+                     float("-inf")).to(tl.float32)
+        S += scale * tl.dot(Q, K)
+
+        context_len = seq_len - 1
 
         if SLIDING_WINDOW > 0:
-            S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
+            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
+                         -10000)
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope * (tmp - seq_len + 1)
+            S += alibi_slope[:, None] * (seq_offset - context_len)
 
         # compute running maximum
-        # m_j : (1,)
-        m_j = tl.maximum(M, tl.max(S, axis=0))
+        # m_j : (num_queries_per_kv,)
+        m_j = tl.maximum(M, tl.max(S, axis=1))
 
-        # P : (BLOCK_SIZE,)
-        P = tl.exp(S - m_j)
+        # P : (num_queries_per_kv, BLOCK_SIZE,)
+        P = tl.exp(S - m_j[:, None])
 
-        # l_j : (1,)
-        l_j = tl.sum(P, axis=0)
+        # l_j : (num_queries_per_kv,)
+        l_j = tl.sum(P, axis=1)
 
-        # alpha : (1, )
+        # alpha : (num_queries_per_kv, )
        alpha = tl.exp(M - m_j)
 
-        # acc : (BLOCK_SIZE,)
-        acc = acc * alpha
+        # acc : (num_queries_per_kv, BLOCK_SIZE,)
+        acc = acc * alpha[:, None]
 
         # update constants
         L = L * alpha + l_j
         M = m_j
 
-        # acc : (BLOCK_SIZE,)
-        acc += tl.sum(V * P[None, :], axis=1)
+        # acc : (num_queries_per_kv, BLOCK_SIZE,)
+        acc += tl.dot(P.to(V.dtype), V)
 
     # epilogue
-    acc = acc / L
+    acc = acc / L[:, None]
 
     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)
 
-    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
-             acc,
-             mask=dim_mask)
+    tl.store(
+        output_ptr + output_offset[:, None] +
+        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        acc,
+        mask=dim_mask[None, :] & head_mask[:, None],
+    )
 
 
 def chunked_prefill_paged_decode(
@@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
+    num_kv_heads = key.shape[1]
     num_queries_per_kv = query.shape[1] // key.shape[1]
     head_size = query.shape[2]
 
@@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
         key_cache = key_cache.view(target_dtype)
         value_cache = value_cache.view(target_dtype)
 
+    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
+                                    16)
+
     kernel_paged_attention_2d[(
         num_seqs,
-        num_query_heads,
+        num_kv_heads,
     )](
         output_ptr=output,
         query_ptr=query,
@@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
         v_scale=v_scale,
         num_query_heads=num_query_heads,
         num_queries_per_kv=num_queries_per_kv,
+        num_queries_per_kv_padded=num_queries_per_kv_padded,
         block_table_stride=block_table.stride(0),
         query_stride_0=query.stride(0),
        query_stride_1=query.stride(1),
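
The central change above is the launch grid: the kernel now runs one program per (sequence, KV head) pair instead of per (sequence, query head) pair, and each program processes the whole group of query heads that share that KV head, padded up to `num_queries_per_kv_padded` so the `tl.dot` operands have a uniform, tensor-core-friendly row count. The sketch below is illustrative only and not part of the diff: it mirrors the host-side padding computation and the in-kernel head indexing/masking in plain Python, using an assumed 32-query-head / 8-KV-head GQA configuration and a local `next_power_of_2` helper standing in for `triton.next_power_of_2`.

```python
# Illustrative sketch (assumed example config, not part of the diff).


def next_power_of_2(n: int) -> int:
    # Stand-in for triton.next_power_of_2 used on the host side.
    return 1 << (n - 1).bit_length()


num_query_heads = 32  # assumed GQA example
num_kv_heads = 8
num_queries_per_kv = num_query_heads // num_kv_heads  # 4

# Host-side padding as in the diff: at least 16 rows so the
# (num_queries_per_kv_padded x HEAD_SIZE) @ (HEAD_SIZE x BLOCK_SIZE)
# tl.dot has a large enough M dimension.
num_queries_per_kv_padded = max(next_power_of_2(num_queries_per_kv), 16)  # 16

# Each program covers one (seq, kv_head) pair; inside it, query_head_idx
# enumerates the padded group and head_mask disables the padding rows.
for kv_head_idx in range(num_kv_heads):
    query_head_idx = [
        kv_head_idx * num_queries_per_kv + i
        for i in range(num_queries_per_kv_padded)
    ]
    head_mask = [
        q < (kv_head_idx + 1) * num_queries_per_kv and q < num_query_heads
        for q in query_head_idx
    ]
    # Exactly num_queries_per_kv real heads survive the mask per program.
    assert sum(head_mask) == num_queries_per_kv
```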