Optimize triton unified attention performance for sliding window attention (#24390)

Signed-off-by: zixi-qi <qizixi@meta.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
qizixi 2025-09-19 12:07:26 -07:00 committed by yewentao256
parent 378c68bead
commit 17b9f3a83d
2 changed files with 25 additions and 3 deletions

View File

@ -83,7 +83,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)

View File

@ -184,8 +184,30 @@ def kernel_unified_attention_2d(
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
# iterate through tiles
for j in range(0, num_tiles):
# ---- Sliding-window tile pruning --------------------
# Default: keep previous global behavior
tile_start = 0
tile_end = num_tiles
if SLIDING_WINDOW > 0:
# Query rows covered by this Q-block
qpos_lo = q_block_local_idx * BLOCK_Q
qpos_hi = tl.minimum(
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
cur_batch_query_len - 1,
)
# For sliding window, each query position q can only attend to
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
# where q_abs = context_len + q
# The union of allowed key positions for this Q-block is:
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
last_allowed_key = context_len + qpos_hi
# Convert to tile indices and clamp
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
# iterate through tiles (now limited to the sliding window range)
for j in range(tile_start, tile_end):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len