diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index f61c8e9b89c24..e60af7434d25c 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -189,9 +189,14 @@ def kernel_unified_attention_2d( + 1 ) - # adjust for potential padding in the last q_block by considering the - # actual sequence length - max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + if USE_MM_PREFIX: + # image bidirectional attention ranges require a full range + # including q_block padding to make sure doc mask is correct + max_seq_prefix_len = tl.maximum(max_seq_prefix_len, seq_len) + else: + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) # calculate the number of tiles that need to be processed to # cover the longest sequence prefix (due to causal masking, tiles beyond @@ -202,7 +207,8 @@ def kernel_unified_attention_2d( # Default: keep previous global behavior tile_start = 0 tile_end = num_tiles - if SLIDING_WINDOW > 0: + # TODO(Isotr0py): sliding window pruning with image bidirectional mask + if SLIDING_WINDOW > 0 and not USE_MM_PREFIX: # Query rows covered by this Q-block qpos_lo = q_block_local_idx * BLOCK_Q qpos_hi = tl.minimum(