[Bugfix] Fix incorrect tile creation for mm prefix triton attention (#30974)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-12-19 16:00:33 +08:00 committed by GitHub
parent 4924ac582c
commit ac1c934276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -189,9 +189,14 @@ def kernel_unified_attention_2d(
+ 1
)
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
if USE_MM_PREFIX:
# image bidirectional attention ranges require a full range
# including q_block padding to make sure doc mask is correct
max_seq_prefix_len = tl.maximum(max_seq_prefix_len, seq_len)
else:
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
@ -202,7 +207,8 @@ def kernel_unified_attention_2d(
# Default: keep previous global behavior
tile_start = 0
tile_end = num_tiles
if SLIDING_WINDOW > 0:
# TODO(Isotr0py): sliding window pruning with image bidirectional mask
if SLIDING_WINDOW > 0 and not USE_MM_PREFIX:
# Query rows covered by this Q-block
qpos_lo = q_block_local_idx * BLOCK_Q
qpos_hi = tl.minimum(