From 17b9f3a83d8edb7317d5803abea7b22013974633 Mon Sep 17 00:00:00 2001 From: qizixi <22851944+zixi-qi@users.noreply.github.com> Date: Fri, 19 Sep 2025 12:07:26 -0700 Subject: [PATCH] Optimize triton unified attention performance for sliding window attention (#24390) Signed-off-by: zixi-qi Signed-off-by: yewentao256 --- .../test_triton_unified_attention.py | 2 +- .../attention/ops/triton_unified_attention.py | 26 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index ab91560e995c8..5cff29b15aa3f 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -83,7 +83,7 @@ def ref_paged_attn( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 591b68bfa6468..9e7cafc174287 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -184,8 +184,30 @@ def kernel_unified_attention_2d( # this prefix can be skipped) num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) - # iterate through tiles - for j in range(0, num_tiles): + # ---- Sliding-window tile pruning -------------------- + # Default: keep previous global behavior + tile_start = 0 + tile_end = num_tiles + if SLIDING_WINDOW > 0: + # Query rows covered by this Q-block + qpos_lo = q_block_local_idx * BLOCK_Q + qpos_hi = tl.minimum( + qpos_lo + (BLOCK_M - 1) // num_queries_per_kv, + cur_batch_query_len - 1, + ) + # For sliding window, each query position q can only attend to + # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs] + # where q_abs = context_len + q + # The union of allowed key positions for this Q-block is: + # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi] + first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + # Convert to tile indices and clamp + tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE) + tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles) + + # iterate through tiles (now limited to the sliding window range) + for j in range(tile_start, tile_end): seq_offset = j * TILE_SIZE + offs_t tile_mask = seq_offset < max_seq_prefix_len