From 0b73736a0d86a2fd0c2548a52eb3877611fa5915 Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Sat, 14 Jun 2025 22:43:48 -0700 Subject: [PATCH] [Kernel] Raise verbose error and consolidate `num_heads/num_kv_heads` divisibility check (#19339) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> --- tests/kernels/attention/test_attention.py | 16 ++++++++++++++++ vllm/attention/backends/blocksparse_attn.py | 4 +--- vllm/attention/backends/dual_chunk_flash_attn.py | 1 - vllm/attention/backends/flash_attn.py | 1 - vllm/attention/backends/flashinfer.py | 1 - vllm/attention/backends/hpu_attn.py | 1 - vllm/attention/backends/ipex_attn.py | 1 - vllm/attention/backends/pallas.py | 3 +-- vllm/attention/backends/rocm_flash_attn.py | 1 - vllm/attention/backends/torch_sdpa.py | 1 - vllm/attention/backends/xformers.py | 1 - vllm/attention/layer.py | 7 ++++++- vllm/v1/attention/backends/flash_attn.py | 1 - vllm/v1/attention/backends/flashinfer.py | 1 - vllm/v1/attention/backends/flex_attention.py | 1 - vllm/v1/attention/backends/pallas.py | 1 - vllm/v1/attention/backends/triton_attn.py | 1 - 17 files changed, 24 insertions(+), 19 deletions(-) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 2d381a99be60c..7269d19183bf2 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -10,6 +10,7 @@ import torch from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck from vllm import _custom_ops as ops +from vllm.attention.layer import Attention, MultiHeadAttention from vllm.platforms import current_platform from vllm.utils import get_max_shared_memory_bytes @@ -506,3 +507,18 @@ def test_multi_query_kv_attention_with_alibi( device, use_alibi=True, ) + + +@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) +def 
test_num_heads_not_divisible_by_num_kv_heads(attention_cls: type) -> None: + head_size = 64 + scale = float(1.0 / (head_size**0.5)) + num_heads = 16 + num_kv_heads = 5 + with pytest.raises(AssertionError): + _ = attention_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + ) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 71415f49372fb..fe9738d804cb1 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -65,7 +65,6 @@ class BlocksparseParams: assert self.block_size > 0 assert self.local_blocks >= 0 assert self.vert_stride >= 1 - assert self.num_heads % self.num_kv_heads == 0 tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -329,9 +328,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): self.head_size = head_size self.scale = float(scale) self.alibi_slopes = alibi_slopes - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.local_blocks = self.blocksparse_params.local_blocks diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index 55f57f37b100a..f62a43b441f23 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -307,7 +307,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if sliding_window is not None: # NOTE(woosuk): flash-attn's sliding window does not work with diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 
47c25d136c67b..bf8e373802f81 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl): logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ff73104787abe..a987dc53878dc 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -957,7 +957,6 @@ class FlashInferImpl(AttentionImpl): self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if attn_type != AttentionType.DECODER: diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 115e5ba1a20f4..bf778a1e5016d 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=torch.bfloat16) self.alibi_slopes = alibi_slopes_tensor - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if self.prefill_impl == 'fsdpa': diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 21f61cf70b28a..410ada3b0828b 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): self.sliding_window = sliding_window self.kv_cache_dtype = kv_cache_dtype - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.need_mask = (self.sliding_window is not 
None) if logits_soft_cap is None: diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index c5c080297cea3..c900666955a32 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl): self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.logits_soft_cap = logits_soft_cap if head_size % 128 != 0: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8f1da84cd4834..1e2c21f4e69d6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -528,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.paged_attn_module = _get_paged_attn_module() diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9d7e735dd41db..3e1336a5ac3b2 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): self.sliding_window = sliding_window self.kv_cache_dtype = kv_cache_dtype - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dfdc8ee6402d5..b583240c73c41 100644 --- a/vllm/attention/backends/xformers.py +++ 
b/vllm/attention/backends/xformers.py @@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): self.sliding_window = sliding_window self.kv_cache_dtype = kv_cache_dtype - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads supported_head_sizes = PagedAttention.get_supported_head_sizes() diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3bbe276e0cbe4..6d9c6f51b34df 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -80,6 +80,9 @@ class Attention(nn.Module): calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads + assert num_heads % num_kv_heads == 0, \ + f"num_heads ({num_heads}) is not " \ + f"divisible by num_kv_heads ({num_kv_heads})" # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with @@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module): self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - assert self.num_heads % self.num_kv_heads == 0 + assert self.num_heads % self.num_kv_heads == 0, \ + f"num_heads ({self.num_heads}) is not " \ + f"divisible by num_kv_heads ({self.num_kv_heads})" self.num_queries_per_kv = self.num_heads // self.num_kv_heads dtype = torch.get_default_dtype() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 630ac13228f14..8b7745ceddd4e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -545,7 +545,6 @@ class FlashAttentionImpl(AttentionImpl): self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() diff --git a/vllm/v1/attention/backends/flashinfer.py 
b/vllm/v1/attention/backends/flashinfer.py index 12547b99e5b6e..b2f54f37a6e19 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -532,7 +532,6 @@ class FlashInferImpl(AttentionImpl): self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if attn_type != AttentionType.DECODER: diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index c8cb1481c8b46..a572b89470f48 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -376,7 +376,6 @@ class FlexAttentionImpl(AttentionImpl): raise NotImplementedError( "FlexAttention does not support logits soft cap yet.") - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 62c72f43f147e..7a6d8c0f85d7b 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -131,7 +131,6 @@ class PallasAttentionBackendImpl(AttentionImpl): self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if head_size % 128 != 0: raise NotImplementedError("Head size must be a multiple of 128.") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6b67d9932e9d8..9782ec087babb 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -114,7 +114,6 @@ class TritonAttentionImpl(AttentionImpl): self.use_irope = use_irope - assert self.num_heads % self.num_kv_heads 
== 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()