diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index c65f09523a3c7..f9645f651351f 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -145,7 +145,19 @@ def kernel_unified_attention_2d(
                           mask=query_mask_1,
                           other=0.0)
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles (blocks) that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # this prefix can be skipped)
+    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
 
     # iterate through tiles
     for j in range(0, num_blocks):
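
For reference, a minimal pure-Python sketch of the tile-count arithmetic this hunk introduces. The helper `num_kv_blocks` and the demo values are illustrative only, not part of the kernel, and the sketch assumes `BLOCK_Q` is the number of query tokens handled per q_block (i.e. `BLOCK_M // num_queries_per_kv`), as the kernel's naming suggests.

```python
import math


def num_kv_blocks(context_len: int, seq_len: int, q_block_local_idx: int,
                  BLOCK_Q: int, BLOCK_M: int, num_queries_per_kv: int,
                  BLOCK_SIZE: int) -> int:
    # The last query row in this q_block sits at local query index
    # q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv.
    # Under causal masking it attends to at most context_len + idx + 1 keys,
    # so that is the longest sequence prefix any row in the block can touch.
    max_seq_prefix_len = (context_len + q_block_local_idx * BLOCK_Q +
                          (BLOCK_M - 1) // num_queries_per_kv + 1)
    # Clamp: the last q_block may be padded past the real sequence length.
    max_seq_prefix_len = min(max_seq_prefix_len, seq_len)
    # Pure-Python stand-in for cdiv_fn(max_seq_prefix_len, BLOCK_SIZE).
    return math.ceil(max_seq_prefix_len / BLOCK_SIZE)


# Illustrative numbers: a 1024-token sequence with no prior context,
# 16-token KV tiles, and one query head per kv head (BLOCK_Q == BLOCK_M == 16).
# Before this change the first q_block scanned cdiv(1024, 16) = 64 KV tiles;
# with the causal bound it needs only 1, since tiles past the last query's
# position are fully masked out anyway.
assert num_kv_blocks(context_len=0, seq_len=1024, q_block_local_idx=0,
                     BLOCK_Q=16, BLOCK_M=16, num_queries_per_kv=1,
                     BLOCK_SIZE=16) == 1
```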