[Bugfix] [Kernel] Triton attention kernels: mask out V blocks that fall outside sliding window (#30887)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2025-12-19 14:39:54 +01:00 committed by GitHub
parent bd2b52fc2d
commit b5545d9d5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -363,6 +363,12 @@ def kernel_unified_attention_2d(
L = L * alpha + l_j
M = m_j
if SLIDING_WINDOW:
qpos_lo = q_block_local_idx * BLOCK_Q
V = tl.where(
(context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
@@ -678,6 +684,12 @@ def kernel_unified_attention_3d(
L = L * alpha + l_j
M = m_j
if SLIDING_WINDOW:
qpos_lo = q_block_local_idx * BLOCK_Q
V = tl.where(
(context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
)
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)