mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 21:24:27 +08:00
[Kernel] Raise verbose error and consolidate num_heads/num_kv_heads divisibility check (#19339)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
ee1531bc38
commit
0b73736a0d
@ -10,6 +10,7 @@ import torch
|
|||||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||||
from tests.kernels.utils import opcheck
|
from tests.kernels.utils import opcheck
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.attention.layer import Attention, MultiHeadAttention
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import get_max_shared_memory_bytes
|
from vllm.utils import get_max_shared_memory_bytes
|
||||||
|
|
||||||
@ -506,3 +507,18 @@ def test_multi_query_kv_attention_with_alibi(
|
|||||||
device,
|
device,
|
||||||
use_alibi=True,
|
use_alibi=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
|
||||||
|
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
|
||||||
|
head_size = 64
|
||||||
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
num_heads = 16
|
||||||
|
num_kv_heads = 5
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
_ = attention_cls(
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
scale=scale,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
)
|
||||||
|
|||||||
@ -65,7 +65,6 @@ class BlocksparseParams:
|
|||||||
assert self.block_size > 0
|
assert self.block_size > 0
|
||||||
assert self.local_blocks >= 0
|
assert self.local_blocks >= 0
|
||||||
assert self.vert_stride >= 1
|
assert self.vert_stride >= 1
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
@ -329,9 +328,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
self.local_blocks = self.blocksparse_params.local_blocks
|
self.local_blocks = self.blocksparse_params.local_blocks
|
||||||
|
|||||||
@ -307,7 +307,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
|||||||
if sliding_window is not None else (-1, -1))
|
if sliding_window is not None else (-1, -1))
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
if sliding_window is not None:
|
if sliding_window is not None:
|
||||||
# NOTE(woosuk): flash-attn's sliding window does not work with
|
# NOTE(woosuk): flash-attn's sliding window does not work with
|
||||||
|
|||||||
@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
logits_soft_cap = 0
|
logits_soft_cap = 0
|
||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||||
|
|||||||
@ -957,7 +957,6 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
|
|||||||
@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
|||||||
alibi_slopes_tensor = torch.tensor(alibi_slopes,
|
alibi_slopes_tensor = torch.tensor(alibi_slopes,
|
||||||
dtype=torch.bfloat16)
|
dtype=torch.bfloat16)
|
||||||
self.alibi_slopes = alibi_slopes_tensor
|
self.alibi_slopes = alibi_slopes_tensor
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
if self.prefill_impl == 'fsdpa':
|
if self.prefill_impl == 'fsdpa':
|
||||||
|
|||||||
@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.need_mask = (self.sliding_window is not None)
|
self.need_mask = (self.sliding_window is not None)
|
||||||
if logits_soft_cap is None:
|
if logits_soft_cap is None:
|
||||||
|
|||||||
@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
if head_size % 128 != 0:
|
if head_size % 128 != 0:
|
||||||
|
|||||||
@ -528,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
if sliding_window is not None else (-1, -1))
|
if sliding_window is not None else (-1, -1))
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
self.paged_attn_module = _get_paged_attn_module()
|
self.paged_attn_module = _get_paged_attn_module()
|
||||||
|
|||||||
@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.need_mask = (self.alibi_slopes is not None
|
self.need_mask = (self.alibi_slopes is not None
|
||||||
or self.sliding_window is not None)
|
or self.sliding_window is not None)
|
||||||
|
|||||||
@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||||
|
|||||||
@ -80,6 +80,9 @@ class Attention(nn.Module):
|
|||||||
calculate_kv_scales = False
|
calculate_kv_scales = False
|
||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
num_kv_heads = num_heads
|
num_kv_heads = num_heads
|
||||||
|
assert num_heads % num_kv_heads == 0, \
|
||||||
|
f"num_heads ({num_heads}) is not " \
|
||||||
|
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||||
|
|
||||||
# The default k/v_scale is set to 1.0. This is ignored
|
# The default k/v_scale is set to 1.0. This is ignored
|
||||||
# when kv-cache is not fp8, and should be used with
|
# when kv-cache is not fp8, and should be used with
|
||||||
@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0, \
|
||||||
|
f"num_heads ({self.num_heads}) is not " \
|
||||||
|
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
dtype = torch.get_default_dtype()
|
dtype = torch.get_default_dtype()
|
||||||
|
|||||||
@ -545,7 +545,6 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||||
|
|||||||
@ -532,7 +532,6 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
|
|||||||
@ -376,7 +376,6 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FlexAttention does not support logits soft cap yet.")
|
"FlexAttention does not support logits soft cap yet.")
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
if kv_sharing_target_layer_name is not None:
|
if kv_sharing_target_layer_name is not None:
|
||||||
|
|||||||
@ -131,7 +131,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
if head_size % 128 != 0:
|
if head_size % 128 != 0:
|
||||||
raise NotImplementedError("Head size must be a multiple of 128.")
|
raise NotImplementedError("Head size must be a multiple of 128.")
|
||||||
|
|||||||
@ -114,7 +114,6 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.use_irope = use_irope
|
self.use_irope = use_irope
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
|
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user