[Kernel] Optimize Prefill Attention in Unified Triton Attention Kernel (#20308)

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
jvlunteren 2025-07-07 21:08:12 +02:00 committed by GitHub
parent a6d795d593
commit 22dd9c2730


@@ -145,7 +145,19 @@ def kernel_unified_attention_2d(
                               mask=query_mask_1,
                               other=0.0)
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles (blocks) that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # this prefix can be skipped)
+    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
 
     # iterate through tiles
     for j in range(0, num_blocks):
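
For context, a minimal plain-Python sketch of the tile-count arithmetic introduced above. The parameter values (seq_len, context_len, BLOCK_SIZE, BLOCK_Q, BLOCK_M, num_queries_per_kv, q_block_local_idx) are illustrative assumptions rather than values from a real run, and cdiv stands in for the kernel's cdiv_fn helper; the sketch only shows how clamping the scan to the longest causal prefix can reduce the number of KV tiles processed for a prefill q_block.

def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring cdiv_fn in the kernel."""
    return (a + b - 1) // b

# Illustrative parameters (assumed for this example only):
seq_len = 1000              # total tokens in the sequence (context + new query)
context_len = 200           # tokens already present before this prefill step
BLOCK_SIZE = 16             # KV-cache block (tile) size
BLOCK_Q = 16                # query tokens handled per q_block
BLOCK_M = 64                # rows in the query tile
num_queries_per_kv = 4      # query heads sharing one KV head (GQA)
q_block_local_idx = 3       # index of this q_block within the sequence

# Longest sequence prefix that any query row in this q_block may attend to
# under causal masking (position of its last query token, plus one), clamped
# to the actual sequence length to account for padding in the last q_block.
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
    BLOCK_M - 1) // num_queries_per_kv + 1
max_seq_prefix_len = min(max_seq_prefix_len, seq_len)

old_num_blocks = cdiv(seq_len, BLOCK_SIZE)            # behavior before this change
new_num_blocks = cdiv(max_seq_prefix_len, BLOCK_SIZE) # behavior after this change

print(old_num_blocks, new_num_blocks)  # 63 vs. 17: tiles 17..62 are skipped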