mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-21 06:59:09 +08:00
[Kernel] Raise verbose error and consolidate num_heads/num_kv_heads divisibility check (#19339)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
ee1531bc38
commit
0b73736a0d
@ -10,6 +10,7 @@ import torch
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
@ -506,3 +507,18 @@ def test_multi_query_kv_attention_with_alibi(
|
||||
device,
|
||||
use_alibi=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
    """Check that attention layers reject incompatible head counts.

    Constructing ``attention_cls`` with a ``num_heads`` that is not a
    multiple of ``num_kv_heads`` must fail with an ``AssertionError`` whose
    message names both values.

    Args:
        attention_cls: The attention layer class under test.
    """
    head_size = 64
    scale = float(1.0 / (head_size**0.5))
    num_heads = 16
    num_kv_heads = 5  # 16 % 5 != 0, so the divisibility check must trip.

    # Match on the message so that an unrelated assertion raised inside the
    # constructor cannot make this test pass by accident.
    expected_message = (rf"num_heads \({num_heads}\) is not divisible by "
                        rf"num_kv_heads \({num_kv_heads}\)")
    with pytest.raises(AssertionError, match=expected_message):
        _ = attention_cls(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
        )
|
||||
|
||||
@ -65,7 +65,6 @@ class BlocksparseParams:
|
||||
assert self.block_size > 0
|
||||
assert self.local_blocks >= 0
|
||||
assert self.vert_stride >= 1
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
@ -329,9 +328,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.local_blocks = self.blocksparse_params.local_blocks
|
||||
|
||||
@ -307,7 +307,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
||||
if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if sliding_window is not None:
|
||||
# NOTE(woosuk): flash-attn's sliding window does not work with
|
||||
|
||||
@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
|
||||
@ -957,7 +957,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
|
||||
@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
||||
alibi_slopes_tensor = torch.tensor(alibi_slopes,
|
||||
dtype=torch.bfloat16)
|
||||
self.alibi_slopes = alibi_slopes_tensor
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if self.prefill_impl == 'fsdpa':
|
||||
|
||||
@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.sliding_window is not None)
|
||||
if logits_soft_cap is None:
|
||||
|
||||
@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if head_size % 128 != 0:
|
||||
|
||||
@ -528,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.paged_attn_module = _get_paged_attn_module()
|
||||
|
||||
@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.alibi_slopes is not None
|
||||
or self.sliding_window is not None)
|
||||
|
||||
@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
|
||||
@ -80,6 +80,9 @@ class Attention(nn.Module):
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
assert self.num_heads % self.num_kv_heads == 0, \
|
||||
f"num_heads ({self.num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
@ -545,7 +545,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
|
||||
@ -532,7 +532,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
|
||||
@ -376,7 +376,6 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support logits soft cap yet.")
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
|
||||
@ -131,7 +131,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if head_size % 128 != 0:
|
||||
raise NotImplementedError("Head size must be a multiple of 128.")
|
||||
|
||||
@ -114,7 +114,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
self.use_irope = use_irope
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user