[Bugfix/CI] Fix broken kernels/test_mha.py (#12450)

Author: Tyler Michael Smith
Date:   2025-01-26 13:39:03 -05:00 (committed by GitHub)
Parent: aa2cd2c43d
Commit: 72f4880425
2 changed files with 10 additions and 2 deletions

File 1 of 2:

@@ -26,7 +26,7 @@ def clear_cache():
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
 def test_mha_attn_platform(device: str):
     """
-    Test that the attention selector between different platform and device.
+    Test the attention selector between different platform and device.
     """
     torch.set_default_dtype(torch.float16)

@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.FLASH_ATTN
+            assert attn.attn_backend == _Backend.XFORMERS
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)

File 2 of 2:

@@ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module):
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
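In plain numbers, the invariant this hunk adds (hypothetical head counts, not from the commit):

    # With 16 query heads and 4 KV heads, each KV head serves 4 query heads
    # (GQA); with num_kv_heads == 1, one KV head would serve all 16 (MQA).
    num_heads, num_kv_heads = 16, 4
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads  # == 4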
@@ -240,6 +243,11 @@ class MultiHeadAttention(nn.Module):
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
         if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
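A self-contained sketch of the expansion this hunk adds, using made-up shapes: repeating each KV head along the head dimension brings key/value up to the query head count, so the downstream attention kernels see matching head counts.

    import torch

    bsz, kv_len, head_size = 2, 5, 64
    num_heads, num_kv_heads = 16, 4
    num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

    key = torch.randn(bsz, kv_len, num_kv_heads, head_size)
    value = torch.randn(bsz, kv_len, num_kv_heads, head_size)

    # Repeat each KV head so K/V match the query head count (MQA/GQA case).
    if num_queries_per_kv > 1:
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)

    assert key.shape == (bsz, kv_len, num_heads, head_size)
    assert value.shape == (bsz, kv_len, num_heads, head_size)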