mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-15 21:47:13 +08:00
[Bugfix] [Kernel] Triton attention kernels: mask out V blocks that fall outside sliding window (#30887)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
bd2b52fc2d
commit
b5545d9d5c
@@ -363,6 +363,12 @@ def kernel_unified_attention_2d(
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
if SLIDING_WINDOW:
|
||||
qpos_lo = q_block_local_idx * BLOCK_Q
|
||||
V = tl.where(
|
||||
(context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
|
||||
)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
@@ -678,6 +684,12 @@ def kernel_unified_attention_3d(
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
if SLIDING_WINDOW:
|
||||
qpos_lo = q_block_local_idx * BLOCK_Q
|
||||
V = tl.where(
|
||||
(context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
|
||||
)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user