diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index 56ebed0f52448..250e9b3890444 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -674,7 +674,8 @@ def unified_attention(
     num_queries_per_kv = num_query_heads // num_kv_heads
     head_size = q.shape[2]
 
-    BLOCK_M = 16
+    BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
+        num_queries_per_kv)
     BLOCK_Q = BLOCK_M // num_queries_per_kv
 
     # Ideally we would launch with kernel with:
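
For context, a minimal sketch (not part of the patch) of why the fixed `BLOCK_M = 16` breaks down for large GQA ratios: since `BLOCK_Q = BLOCK_M // num_queries_per_kv` uses integer division, any `num_queries_per_kv` above 16 drives `BLOCK_Q` to 0, while growing `BLOCK_M` to the next power of two keeps it at least 1. The `next_power_of_2` helper below stands in for `triton.next_power_of_2` so the snippet runs without Triton installed; the loop values are illustrative assumptions, not taken from vLLM.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (same contract as triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


for num_queries_per_kv in (1, 4, 16, 32, 48, 64):
    # Old behavior: BLOCK_M fixed at 16, so BLOCK_Q collapses to 0 once
    # num_queries_per_kv exceeds 16.
    old_block_m = 16
    old_block_q = old_block_m // num_queries_per_kv

    # New behavior: grow BLOCK_M to the next power of two of
    # num_queries_per_kv, so BLOCK_Q stays at least 1.
    new_block_m = (16 if num_queries_per_kv <= 16
                   else next_power_of_2(num_queries_per_kv))
    new_block_q = new_block_m // num_queries_per_kv

    print(f"q/kv={num_queries_per_kv:3d}  "
          f"old: BLOCK_M={old_block_m:3d} BLOCK_Q={old_block_q}  "
          f"new: BLOCK_M={new_block_m:3d} BLOCK_Q={new_block_q}")
```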