[Bugfix/CI] Fix broken kernels/test_mha.py (#12450)
parent aa2cd2c43d
commit 72f4880425
@@ -26,7 +26,7 @@ def clear_cache():
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
 def test_mha_attn_platform(device: str):
     """
-    Test that the attention selector between different platform and device.
+    Test the attention selector between different platform and device.
     """
     torch.set_default_dtype(torch.float16)
 
@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.FLASH_ATTN
+            assert attn.attn_backend == _Backend.XFORMERS
 
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)
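For context, the test patches vllm.attention.selector.current_platform so the selector sees a chosen platform while MultiHeadAttention is constructed, then asserts which backend was picked. The sketch below shows only that unittest.mock.patch pattern on a hypothetical stand-in module; demo_selector and report_platform are made-up names, not vLLM APIs.

    import types
    from unittest.mock import patch

    # Hypothetical stand-in for a selector module exposing current_platform.
    demo_selector = types.SimpleNamespace(current_platform="cpu")

    def report_platform() -> str:
        # Reads the module-level attribute, the same way the selector consults
        # current_platform when choosing an attention backend.
        return demo_selector.current_platform

    with patch.object(demo_selector, "current_platform", "cuda"):
        assert report_platform() == "cuda"  # patched inside the context
    assert report_platform() == "cpu"       # original value restored afterwards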
@@ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module):
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
@@ -240,6 +243,11 @@ class MultiHeadAttention(nn.Module):
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
         if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
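The block added above expands the key and value tensors when the model uses multi-query or grouped-query attention, so every query head has a matching KV head before the backend kernel runs. Below is a minimal, self-contained sketch of that shape transformation; the batch size, sequence length, and head counts are made-up numbers, not values taken from vLLM.

    import torch

    # Hypothetical GQA configuration: 16 query heads share 4 KV heads.
    bsz, kv_len, head_size = 2, 8, 64
    num_heads, num_kv_heads = 16, 4
    num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

    key = torch.randn(bsz, kv_len, num_kv_heads, head_size)
    value = torch.randn(bsz, kv_len, num_kv_heads, head_size)

    if (num_repeat := num_queries_per_kv) > 1:
        # Repeat each KV head along the head dimension (dim=2) so the layout
        # matches the query tensor's (bsz, kv_len, num_heads, head_size).
        key = torch.repeat_interleave(key, num_repeat, dim=2)
        value = torch.repeat_interleave(value, num_repeat, dim=2)

    assert key.shape == (bsz, kv_len, num_heads, head_size)
    assert value.shape == (bsz, kv_len, num_heads, head_size)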