[Kernel] Decouple Tile Size from Block Size in Triton Unified Attention Kernel (#21197)

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
This commit is contained in:
jvlunteren 2025-09-18 16:27:01 +02:00 committed by GitHub
parent bc19d75985
commit 01a583fea4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 70 additions and 55 deletions

View File

@ -102,9 +102,6 @@ def test_triton_unified_attn(
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
pytest.skip("block size must be at least 32 for fp8")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]

View File

@ -73,6 +73,7 @@ def kernel_unified_attention_2d(
output_stride_1: tl.int64, # int, should be equal to head_size output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int must be power of 2
HEAD_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
@ -118,6 +119,7 @@ def kernel_unified_attention_2d(
offs_m = tl.arange(0, BLOCK_M) offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_0 = cur_batch_in_all_start_index + query_pos
@ -177,31 +179,32 @@ def kernel_unified_attention_2d(
# actual sequence length # actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles (blocks) that need to be processed to # calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, blocks beyond # cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped) # this prefix can be skipped)
num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
# iterate through tiles # iterate through tiles
for j in range(0, num_blocks): for j in range(0, num_tiles):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
seq_offset // BLOCK_SIZE).to(tl.int64)
offs_n = tl.arange(0, BLOCK_SIZE) v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
v_offset = (physical_block_idx * stride_v_cache_0 +
kv_head_idx * stride_v_cache_2 + kv_head_idx * stride_v_cache_2 +
offs_d[None, :] * stride_v_cache_3 + offs_d[None, :] * stride_v_cache_3 +
offs_n[:, None] * stride_v_cache_1) (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
k_offset = (physical_block_idx * stride_k_cache_0 + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
kv_head_idx * stride_k_cache_2 + kv_head_idx * stride_k_cache_2 +
offs_d[:, None] * stride_k_cache_3 + offs_d[:, None] * stride_k_cache_3 +
offs_n[None, :] * stride_k_cache_1) (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
# K : (HEAD_SIZE, BLOCK_SIZE) # K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(key_cache_ptr + k_offset, K_load = tl.load(key_cache_ptr + k_offset,
mask=dim_mask[:, None], mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0) other=0.0)
if K_load.dtype.is_fp8(): if K_load.dtype.is_fp8():
@ -212,9 +215,9 @@ def kernel_unified_attention_2d(
else: else:
K = K_load K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE) # V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset, V_load = tl.load(value_cache_ptr + v_offset,
mask=dim_mask[None, :], mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0) other=0.0)
if V_load.dtype.is_fp8(): if V_load.dtype.is_fp8():
@ -225,12 +228,10 @@ def kernel_unified_attention_2d(
else: else:
V = V_load V = V_load
seq_offset = j * BLOCK_SIZE + offs_n
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# S : (BLOCK_M, BLOCK_SIZE) # S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K) S += scale * tl.dot(Q, K)
@ -262,11 +263,12 @@ def kernel_unified_attention_2d(
# compute running maximum # compute running maximum
# m_j : (BLOCK_M,) # m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1)) m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of # For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN # the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0) m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
# P : (BLOCK_M, BLOCK_SIZE) # P : (BLOCK_M, TILE_SIZE)
P = tl.exp(S - m_j[:, None]) P = tl.exp(S - m_j[:, None])
# l_j : (BLOCK_M,) # l_j : (BLOCK_M,)
@ -327,6 +329,7 @@ def kernel_unified_attention_3d(
query_stride_1: tl.int64, # int, should be equal to head_size query_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int BLOCK_SIZE: tl.constexpr, # int
TILE_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE: tl.constexpr, # int HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
@ -374,20 +377,19 @@ def kernel_unified_attention_3d(
# number of segments for this particular sequence # number of segments for this particular sequence
num_segments = NUM_SEGMENTS_PER_SEQ num_segments = NUM_SEGMENTS_PER_SEQ
blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
return return
offs_m = tl.arange(0, BLOCK_M) offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + \ query_offset_1 = kv_head_idx * num_queries_per_kv + \
offs_m % num_queries_per_kv offs_m % num_queries_per_kv
query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset = (query_offset_0[:, None] * query_stride_0 +
query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
@ -433,30 +435,44 @@ def kernel_unified_attention_3d(
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
) # shape: [BLOCK_M] ) # shape: [BLOCK_M]
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
BLOCK_M - 1) // num_queries_per_kv + 1
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
# iterate through tiles within current segment # iterate through tiles within current segment
for j in range( for j in range(
segm_idx * blocks_per_segment, segm_idx * tiles_per_segment,
min((segm_idx + 1) * blocks_per_segment, num_blocks), min((segm_idx + 1) * tiles_per_segment, num_tiles),
): ):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len
offs_n = tl.arange(0, BLOCK_SIZE) physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
seq_offset // BLOCK_SIZE).to(tl.int64)
v_offset = (physical_block_idx * stride_v_cache_0 + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
kv_head_idx * stride_v_cache_2 + kv_head_idx * stride_v_cache_2 +
offs_d[None, :] * stride_v_cache_3 + offs_d[None, :] * stride_v_cache_3 +
offs_n[:, None] * stride_v_cache_1) (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
k_offset = (physical_block_idx * stride_k_cache_0 + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
kv_head_idx * stride_k_cache_2 + kv_head_idx * stride_k_cache_2 +
offs_d[:, None] * stride_k_cache_3 + offs_d[:, None] * stride_k_cache_3 +
offs_n[None, :] * stride_k_cache_1) (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
# K : (HEAD_SIZE, BLOCK_SIZE) # K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(key_cache_ptr + k_offset, K_load = tl.load(key_cache_ptr + k_offset,
mask=dim_mask[:, None], mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0) other=0.0)
if K_load.dtype.is_fp8(): if K_load.dtype.is_fp8():
@ -467,9 +483,9 @@ def kernel_unified_attention_3d(
else: else:
K = K_load K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE) # V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset, V_load = tl.load(value_cache_ptr + v_offset,
mask=dim_mask[None, :], mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0) other=0.0)
if V_load.dtype.is_fp8(): if V_load.dtype.is_fp8():
@ -480,13 +496,10 @@ def kernel_unified_attention_3d(
else: else:
V = V_load V = V_load
seq_offset = j * BLOCK_SIZE + offs_n
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# S : (BLOCK_M, BLOCK_SIZE) # S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K) S += scale * tl.dot(Q, K)
if USE_SOFTCAP: if USE_SOFTCAP:
@ -517,11 +530,12 @@ def kernel_unified_attention_3d(
# compute running maximum # compute running maximum
# m_j : (BLOCK_M,) # m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1)) m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of # For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN # the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0) m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
# P : (BLOCK_M, BLOCK_SIZE,) # P : (BLOCK_M, TILE_SIZE,)
P = tl.exp(S - m_j[:, None]) P = tl.exp(S - m_j[:, None])
# l_j : (BLOCK_M,) # l_j : (BLOCK_M,)
@ -573,7 +587,7 @@ def reduce_segments(
output_stride_0: tl.int64, # int output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int block_table_stride: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int TILE_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1] query_start_len_ptr, # [num_seqs+1]
@ -594,10 +608,10 @@ def reduce_segments(
# number of segments for this particular sequence # number of segments for this particular sequence
num_segments = NUM_SEGMENTS_PER_SEQ num_segments = NUM_SEGMENTS_PER_SEQ
blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
# create masks for subsequent loads # create masks for subsequent loads
act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
[NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
@ -671,13 +685,10 @@ def unified_attention(
# Optional tensor for sinks # Optional tensor for sinks
sinks=None, sinks=None,
): ):
assert causal, "Only causal attention is supported" assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported" assert q_descale is None, "Q scales not supported"
block_size = v.shape[1]
assert q.element_size() >= 2 or block_size >= 32, \
"Block size must be at least 32 for fp8"
if sinks is not None: if sinks is not None:
assert sinks.shape[0] == q.shape[1], \ assert sinks.shape[0] == q.shape[1], \
"Sinks must be num_query_heads size" "Sinks must be num_query_heads size"
@ -707,6 +718,12 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs # = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# if batch contains a prefill # if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
kernel_unified_attention_2d[( kernel_unified_attention_2d[(
@ -736,6 +753,7 @@ def unified_attention(
output_stride_1=out.stride(1), output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size, BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size, HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes, USE_ALIBI_SLOPES=use_alibi_slopes,
@ -809,6 +827,7 @@ def unified_attention(
query_stride_1=q.stride(1), query_stride_1=q.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size, BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size, HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes, USE_ALIBI_SLOPES=use_alibi_slopes,
@ -830,7 +849,6 @@ def unified_attention(
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
) )
reduce_segments[(q.shape[0], num_query_heads)]( reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out, output_ptr=out,
segm_output_ptr=segm_output, segm_output_ptr=segm_output,
@ -844,7 +862,7 @@ def unified_attention(
output_stride_0=out.stride(0), output_stride_0=out.stride(0),
output_stride_1=out.stride(1), output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0), block_table_stride=block_table.stride(0),
BLOCK_SIZE=block_size, TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size, HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q, query_start_len_ptr=cu_seqlens_q,