diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index e60af7434d25c..c946dbd8a2c4e 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -363,6 +363,12 @@ def kernel_unified_attention_2d( L = L * alpha + l_j M = m_j + if SLIDING_WINDOW: + qpos_lo = q_block_local_idx * BLOCK_Q + V = tl.where( + (context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0 + ) + # acc : (BLOCK_M, HEAD_SIZE_PADDED) acc += tl.dot(P.to(V.dtype), V) @@ -678,6 +684,12 @@ def kernel_unified_attention_3d( L = L * alpha + l_j M = m_j + if SLIDING_WINDOW: + qpos_lo = q_block_local_idx * BLOCK_Q + V = tl.where( + (context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0 + ) + # acc : (BLOCK_M, HEAD_SIZE_PADDED) acc += tl.dot(P.to(V.dtype), V)