[Bugfix] Fixing division by zero in triton_attn if query_heads/kv_heads > 16 (#23424)

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
This commit is contained in:
Burkhard Ringlein 2025-09-03 17:01:09 +02:00 committed by GitHub
parent 4ba0c587ba
commit 6d80ae83e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -674,7 +674,8 @@ def unified_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = q.shape[2]
BLOCK_M = 16
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
num_queries_per_kv)
BLOCK_Q = BLOCK_M // num_queries_per_kv
# Ideally we would launch with kernel with: