mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 14:09:51 +08:00
Optimize triton unified attention performance for sliding window attention (#24390)
Signed-off-by: zixi-qi <qizixi@meta.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
378c68bead
commit
17b9f3a83d
@ -83,7 +83,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
|
||||
@ -184,8 +184,30 @@ def kernel_unified_attention_2d(
|
||||
# this prefix can be skipped)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# iterate through tiles
|
||||
for j in range(0, num_tiles):
|
||||
# ---- Sliding-window tile pruning --------------------
|
||||
# Default: keep previous global behavior
|
||||
tile_start = 0
|
||||
tile_end = num_tiles
|
||||
if SLIDING_WINDOW > 0:
|
||||
# Query rows covered by this Q-block
|
||||
qpos_lo = q_block_local_idx * BLOCK_Q
|
||||
qpos_hi = tl.minimum(
|
||||
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
|
||||
cur_batch_query_len - 1,
|
||||
)
|
||||
# For sliding window, each query position q can only attend to
|
||||
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
|
||||
# where q_abs = context_len + q
|
||||
# The union of allowed key positions for this Q-block is:
|
||||
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
|
||||
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
|
||||
last_allowed_key = context_len + qpos_hi
|
||||
# Convert to tile indices and clamp
|
||||
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
|
||||
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
|
||||
|
||||
# iterate through tiles (now limited to the sliding window range)
|
||||
for j in range(tile_start, tile_end):
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user