[Misc] Add FA2 support to ViT MHA layer (#12355)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent bf21481dde
commit f1fc0510df
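In brief: the ViT-style MultiHeadAttention layer previously remapped FlashAttention backends to xFormers; with this commit it can dispatch to FlashAttention-2 directly, still falling back to xFormers or torch SDPA otherwise. Below is a minimal standalone sketch of that dispatch idea, not vLLM's actual implementation: the names naive_sdpa and vit_mha and the use_flash_attn flag are illustrative, and the flash_attn import assumes the upstream flash-attn 2.x package is installed.

import torch


def naive_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
               scale: float) -> torch.Tensor:
    # Reference attention over [batch, seq, heads, head_dim] tensors.
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))  # -> [b, h, s, d]
    attn = torch.softmax(scale * q @ k.transpose(-2, -1), dim=-1)
    return (attn @ v).transpose(1, 2)                 # -> [b, s, h, d]


def vit_mha(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            scale: float, use_flash_attn: bool = False) -> torch.Tensor:
    # Dispatch to FlashAttention-2 when requested, else the naive path.
    if use_flash_attn:
        from flash_attn import flash_attn_func  # assumes flash-attn is installed
        return flash_attn_func(q, k, v, softmax_scale=scale)
    return naive_sdpa(q, k, v, scale)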
tests/kernels/test_mha_attn.py  (new file, +126 lines)
@@ -0,0 +1,126 @@
"""
Test:

* Tests for MultiHeadAttention layer
"""
from unittest.mock import patch

import pytest
import torch

from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching."""
    _cached_get_attn_backend.cache_clear()


@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
    """
    Test the attention backend selection across different platforms and devices.
    """
    torch.set_default_dtype(torch.float16)

    if device == "cpu":
        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
            attn = MultiHeadAttention(16, 64, scale=1)
            assert attn.attn_backend == _Backend.TORCH_SDPA
    elif device == "hip":
        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
            attn = MultiHeadAttention(16, 64, scale=1)
            assert attn.attn_backend == _Backend.TORCH_SDPA
    else:
        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
            attn = MultiHeadAttention(16, 64, scale=1)
            assert attn.attn_backend == _Backend.FLASH_ATTN

        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
            attn = MultiHeadAttention(16, 72, scale=1)
            assert attn.attn_backend == _Backend.XFORMERS


def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - attn_mask: [batch_size, seq_len, seq_len]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
    torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    current_platform.seed_everything(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    q = torch.randn(batch_size, seq_len, num_heads * head_size)
    k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads
    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
    if num_queries_per_kv > 1:
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)
    torch.testing.assert_close(output, ref_output)
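For context on the GQA/MQA handling in test_mha_attn_forward above: the reference path expands the grouped KV heads with torch.repeat_interleave so that every query head attends to its own copy of the shared KV head. A tiny worked example of that expansion (shapes chosen only for illustration):

import torch

# 1 KV head shared by 4 query heads (MQA); expand KV along the head dim.
batch, seq, num_heads, num_kv_heads, head_size = 2, 3, 4, 1, 8
k = torch.randn(batch, seq, num_kv_heads, head_size)
k_expanded = torch.repeat_interleave(k, num_heads // num_kv_heads, dim=2)
assert k_expanded.shape == (batch, seq, num_heads, head_size)
# Every query head now sees an identical copy of the single KV head.
assert torch.equal(k_expanded[:, :, 0], k_expanded[:, :, 3])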
vllm/attention/layer.py
@@ -210,6 +210,9 @@ class MultiHeadAttention(nn.Module):
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
@@ -217,11 +220,12 @@ class MultiHeadAttention(nn.Module):
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+            _Backend.TORCH_SDPA,
+            _Backend.XFORMERS,
+            _Backend.FLASH_ATTN,
+            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -231,7 +235,6 @@ class MultiHeadAttention(nn.Module):
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
-        # TODO(Isotr0py): Use existing backend implementations and support FA2
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
@@ -239,7 +242,19 @@ class MultiHeadAttention(nn.Module):
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+            from vllm.vllm_flash_attn import flash_attn_func
+
+            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
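After this change the layer keeps the FlashAttention backends in its supported set rather than remapping them to xFormers, and anything outside that set still degrades to torch SDPA. A condensed standalone sketch of that selection rule, where a plain Enum stands in for vLLM's _Backend (illustrative, not the library code):

from enum import Enum, auto


class Backend(Enum):
    TORCH_SDPA = auto()
    XFORMERS = auto()
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    OTHER = auto()


def select_mha_backend(detected: Backend) -> Backend:
    # Keep the detected backend if the ViT MHA layer supports it,
    # otherwise fall back to torch scaled_dot_product_attention.
    supported = {
        Backend.TORCH_SDPA,
        Backend.XFORMERS,
        Backend.FLASH_ATTN,
        Backend.FLASH_ATTN_VLLM_V1,
    }
    return detected if detected in supported else Backend.TORCH_SDPA


assert select_mha_backend(Backend.FLASH_ATTN) is Backend.FLASH_ATTN
assert select_mha_backend(Backend.OTHER) is Backend.TORCH_SDPA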