[Kernel] [V1] Further optimizations to ROCm (Triton) Backend to better handle GQA. (#14431)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Thomas Parnell, 2025-03-14 04:42:27 +01:00, committed by GitHub
parent 0b1cfa6180
commit fb4c7f8ef0

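Summary: with grouped-query attention (GQA), several query heads share one KV head. Before this change the kernel launched one Triton program per (sequence, query head) and each program processed a single query row; after it, the grid's second axis ranges over KV heads, and each program handles the whole query group of its KV head using tl.dot matmuls. A rough sketch of the regrouping, with illustrative head counts (not taken from this PR):

    num_seqs = 16
    num_query_heads, num_kv_heads = 32, 8
    num_queries_per_kv = num_query_heads // num_kv_heads  # 4 query heads per KV head

    grid_before = (num_seqs, num_query_heads)  # one program per (seq, query head)
    grid_after = (num_seqs, num_kv_heads)      # one program per (seq, KV head)
    print(grid_before, grid_after)             # (16, 32) -> (16, 8)
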
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # Authors:
-#  - Burkhard Ringlein
-#  - Jan van Lunteren
-#  - Thomas Parnell
+#  - Burkhard Ringlein <ngl@zurich.ibm.com>
+#  - Jan van Lunteren <jvl@zurich.ibm.com>
+#  - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
+#  - Thomas Parnell <tpa@zurich.ibm.com>

 import torch
 import triton
@@ -31,6 +32,7 @@ def kernel_paged_attention_2d(
         v_scale,  # float32
         num_query_heads: tl.constexpr,  # int
         num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
         block_table_stride: tl.constexpr,  # int
         query_stride_0: tl.constexpr,  # int
         query_stride_1: tl.constexpr,  # int, should be equal to head_size
@@ -55,8 +57,7 @@ def kernel_paged_attention_2d(
         query_start_len_ptr,  # [num_seqs+1]
 ):
     seq_idx = tl.program_id(0)
-    query_head_idx = tl.program_id(1)
-    kv_head_idx = query_head_idx // num_queries_per_kv
+    kv_head_idx = tl.program_id(1)

     if filter_by_query_len:
         cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
@@ -69,31 +70,40 @@ def kernel_paged_attention_2d(
     else:
         cur_batch_in_all_start_index = seq_idx

+    query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(
+        0, num_queries_per_kv_padded)
+
     query_offset = (cur_batch_in_all_start_index * query_stride_0 +
-                    query_head_idx * query_stride_1)
+                    query_head_idx[:, None] * query_stride_1)
+
+    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
+    head_mask = head_mask & (query_head_idx < num_query_heads)

     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
                         0).to(tl.int1)

-    # Q : (HEAD_SIZE,)
+    # Q : (num_queries_per_kv, HEAD_SIZE,)
     Q = tl.load(
-        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED),
-        mask=dim_mask,
+        query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        mask=dim_mask[None, :] & head_mask[:, None],
         other=0.0,
     )

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([1], float("-inf"), dtype=tl.float32)
-    L = tl.full([1], 1.0, dtype=tl.float32)
-    acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32)
+    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
+    acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
+                   dtype=tl.float32)

     # sequence len for this particular sequence
     seq_len = tl.load(seq_lens_ptr + seq_idx)

     # alibi slope for this head
     if USE_ALIBI_SLOPES:
-        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx)
+        alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx,
+                              mask=head_mask,
+                              other=0.0)

     num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
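
Note: the query-group axis is padded to num_queries_per_kv_padded (a power of two, floored at 16; see the host-side change further down), so head_mask is needed to switch off the padded lanes and any index past the total head count. A minimal NumPy sketch of the masking logic, with hypothetical head counts:

    import numpy as np

    num_query_heads = 40            # hypothetical: 40 query heads over 8 KV heads
    num_queries_per_kv = 5          # 40 // 8
    num_queries_per_kv_padded = 16  # next power of two, floored at 16

    kv_head_idx = 7                 # last KV-head group
    query_head_idx = (kv_head_idx * num_queries_per_kv
                      + np.arange(num_queries_per_kv_padded))    # [35 .. 50]
    head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv
    head_mask &= query_head_idx < num_query_heads
    print(query_head_idx[head_mask])  # [35 36 37 38 39], padded lanes dropped
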
@@ -107,8 +117,8 @@ def kernel_paged_attention_2d(
         v_offset = (physical_block_idx * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_1 +
-                    offs_d[:, None] * stride_v_cache_2 +
-                    offs_n[None, :] * stride_v_cache_3)
+                    offs_d[None, :] * stride_v_cache_2 +
+                    offs_n[:, None] * stride_v_cache_3)

         k_offset = (physical_block_idx * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_1 +
@@ -126,9 +136,9 @@ def kernel_paged_attention_2d(
         else:
             K = K_load

-        # V : (HEAD_SIZE, BLOCK_SIZE)
+        # V : (BLOCK_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[None, :],
                          other=0.0)

         if V_load.dtype.is_fp8():
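
Note: the value tile is now addressed as (BLOCK_SIZE, HEAD_SIZE) instead of (HEAD_SIZE, BLOCK_SIZE) so that, in the next hunk, both products become plain tl.dot calls. A shape-only NumPy sketch of the two matmuls (sizes illustrative):

    import numpy as np

    G, D, B = 16, 128, 32   # padded query group, padded head size, block size
    Q = np.zeros((G, D))    # query group for one KV head
    K = np.zeros((D, B))    # keys stay in (head_size, block) layout
    V = np.zeros((B, D))    # values now load in (block, head_size) layout

    S = Q @ K               # scores (G, B), cf. scale * tl.dot(Q, K)
    P = np.exp(S - S.max(axis=1, keepdims=True))  # softmax numerator per row
    acc = P @ V             # output (G, D), cf. tl.dot(P.to(V.dtype), V)
    print(S.shape, acc.shape)                     # (16, 32) (16, 128)
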
@@ -136,51 +146,59 @@ def kernel_paged_attention_2d(
         else:
             V = V_load

-        tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
-        mask_new = tmp < boundary
-        # S : (BLOCK_SIZE,)
-        S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32)
-        S += scale * tl.sum(K * Q[:, None], axis=0)
+        seq_mask = seq_offset[None, :] < boundary
+
+        # S : (num_queries_per_kv, BLOCK_SIZE,)
+        S = tl.where(head_mask[:, None] & seq_mask, 0.0,
+                     float("-inf")).to(tl.float32)
+        S += scale * tl.dot(Q, K)
+
+        context_len = seq_len - 1

         if SLIDING_WINDOW > 0:
-            S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000)
+            S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S,
+                         -10000)

         if USE_ALIBI_SLOPES:
-            S += alibi_slope * (tmp - seq_len + 1)
+            S += alibi_slope[:, None] * (seq_offset - context_len)

         # compute running maximum
-        # m_j : (1,)
-        m_j = tl.maximum(M, tl.max(S, axis=0))
+        # m_j : (num_queries_per_kv,)
+        m_j = tl.maximum(M, tl.max(S, axis=1))

-        # P : (BLOCK_SIZE,)
-        P = tl.exp(S - m_j)
+        # P : (num_queries_per_kv, BLOCK_SIZE,)
+        P = tl.exp(S - m_j[:, None])

-        # l_j : (1,)
-        l_j = tl.sum(P, axis=0)
+        # l_j : (num_queries_per_kv,)
+        l_j = tl.sum(P, axis=1)

-        # alpha : (1, )
+        # alpha : (num_queries_per_kv, )
         alpha = tl.exp(M - m_j)

-        # acc : (BLOCK_SIZE,)
-        acc = acc * alpha
+        # acc : (num_queries_per_kv, HEAD_SIZE_PADDED,)
+        acc = acc * alpha[:, None]

         # update constants
         L = L * alpha + l_j
         M = m_j

-        # acc : (BLOCK_SIZE,)
-        acc += tl.sum(V * P[None, :], axis=1)
+        # acc : (num_queries_per_kv, HEAD_SIZE_PADDED,)
+        acc += tl.dot(P.to(V.dtype), V)

     # epilogue
-    acc = acc / L
+    acc = acc / L[:, None]

     output_offset = (cur_batch_in_all_start_index * output_stride_0 +
                      query_head_idx * output_stride_1)

-    tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED),
-             acc,
-             mask=dim_mask)
+    tl.store(
+        output_ptr + output_offset[:, None] +
+        tl.arange(0, HEAD_SIZE_PADDED)[None, :],
+        acc,
+        mask=dim_mask[None, :] & head_mask[:, None],
+    )
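
Note: the M/L/acc updates are the standard streaming ("online") softmax, now carried per query row of the group: each KV block rescales the running state by alpha = exp(M - m_j) before folding in the new block. A self-contained NumPy sketch that checks the recurrence against a one-shot softmax attention (all sizes and data illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    G, D, B, num_blocks = 4, 8, 16, 3  # query group, head size, block size
    Q = rng.normal(size=(G, D))
    K = rng.normal(size=(D, B * num_blocks))
    V = rng.normal(size=(B * num_blocks, D))

    M = np.full(G, -np.inf)            # running row-wise maximum
    L = np.ones(G)                     # running softmax denominator
    acc = np.zeros((G, D))             # running weighted sum of V rows

    for j in range(num_blocks):
        S = Q @ K[:, j * B:(j + 1) * B]         # (G, B) block of scores
        m_j = np.maximum(M, S.max(axis=1))      # new running maximum
        P = np.exp(S - m_j[:, None])
        alpha = np.exp(M - m_j)                 # rescales the old state
        acc = acc * alpha[:, None] + P @ V[j * B:(j + 1) * B]
        L = L * alpha + P.sum(axis=1)
        M = m_j

    out = acc / L[:, None]             # epilogue, as in the kernel

    S_full = Q @ K                     # reference: one-shot computation
    P_full = np.exp(S_full - S_full.max(axis=1, keepdims=True))
    ref = (P_full / P_full.sum(axis=1, keepdims=True)) @ V
    assert np.allclose(out, ref)
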
@@ -234,6 +252,7 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
+    num_kv_heads = key.shape[1]
     num_queries_per_kv = query.shape[1] // key.shape[1]
     head_size = query.shape[2]
@@ -253,9 +272,12 @@ def chunked_prefill_paged_decode(
         key_cache = key_cache.view(target_dtype)
         value_cache = value_cache.view(target_dtype)

+    num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
+                                    16)
+
     kernel_paged_attention_2d[(
         num_seqs,
-        num_query_heads,
+        num_kv_heads,
     )](
         output_ptr=output,
         query_ptr=query,
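
Note: num_queries_per_kv_padded rounds the group size up to a power of two and floors it at 16, presumably so both tl.dot operands meet Triton's minimum tile size; the padded rows stay inert thanks to head_mask. Illustrative values:

    import triton

    num_query_heads, num_kv_heads = 32, 4   # hypothetical model shape
    num_queries_per_kv = num_query_heads // num_kv_heads           # 8
    num_queries_per_kv_padded = max(
        triton.next_power_of_2(num_queries_per_kv), 16)            # 16
    grid = (2, num_kv_heads)   # (num_seqs, num_kv_heads); was (num_seqs, num_query_heads)
    print(num_queries_per_kv_padded, grid)  # 16 (2, 4)
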
@@ -269,6 +291,7 @@ def chunked_prefill_paged_decode(
         v_scale=v_scale,
         num_query_heads=num_query_heads,
         num_queries_per_kv=num_queries_per_kv,
+        num_queries_per_kv_padded=num_queries_per_kv_padded,
         block_table_stride=block_table.stride(0),
         query_stride_0=query.stride(0),
         query_stride_1=query.stride(1),