mirror of https://git.datalinker.icu/vllm-project/vllm.git
parent 0ee349b553
commit a5255270c3
@ -210,9 +210,6 @@ class MultiHeadAttention(nn.Module):
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
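The ratio computed above is what drives MQA/GQA support: each of the num_kv_heads key/value heads is shared by num_queries_per_kv query heads, and forward() below expands K/V by that factor with repeat_interleave. A minimal sketch of the arithmetic (the concrete sizes are illustrative, not taken from this diff):

import torch

num_heads, num_kv_heads, head_size = 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

# (bsz, kv_len, num_kv_heads, head_size) -> (bsz, kv_len, num_heads, head_size)
key = torch.randn(1, 16, num_kv_heads, head_size)
key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
assert key.shape == (1, 16, num_heads, head_size)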
@ -220,12 +217,12 @@ class MultiHeadAttention(nn.Module):
|
||||
block_size=16,
|
||||
is_attention_free=False)
|
||||
backend = backend_name_to_enum(attn_backend.get_name())
|
||||
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
|
||||
backend = _Backend.XFORMERS
|
||||
|
||||
self.attn_backend = backend if backend in {
|
||||
_Backend.TORCH_SDPA,
|
||||
_Backend.XFORMERS,
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLASH_ATTN_VLLM_V1,
|
||||
} else _Backend.TORCH_SDPA
|
||||
|
||||
def forward(
|
||||
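For reference, the final assignment above degrades any detected backend outside the supported set to plain PyTorch SDPA. A minimal sketch of that selection in isolation (the _Backend stub below is an assumption so the snippet runs standalone; in vLLM the enum and backend_name_to_enum come from the attention-selector machinery):

from enum import Enum, auto

class _Backend(Enum):
    TORCH_SDPA = auto()
    XFORMERS = auto()
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    OTHER = auto()  # stand-in for any backend outside the supported set

def pick_mha_backend(detected: _Backend) -> _Backend:
    supported = {
        _Backend.TORCH_SDPA,
        _Backend.XFORMERS,
        _Backend.FLASH_ATTN,
        _Backend.FLASH_ATTN_VLLM_V1,
    }
    # Mirrors `backend if backend in {...} else _Backend.TORCH_SDPA` above.
    return detected if detected in supported else _Backend.TORCH_SDPA

assert pick_mha_backend(_Backend.OTHER) is _Backend.TORCH_SDPA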
@ -235,6 +232,7 @@ class MultiHeadAttention(nn.Module):
|
||||
value: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: batch_size x seq_len x hidden_size"""
|
||||
# TODO(Isotr0py): Use existing backend implementations and support FA3
|
||||
bsz, q_len, _ = query.size()
|
||||
kv_len = key.size(1)
|
||||
|
||||
@ -242,38 +240,7 @@ class MultiHeadAttention(nn.Module):
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
if self.attn_backend in {
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLASH_ATTN_VLLM_V1,
|
||||
}:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
|
||||
step=kv_len,
|
||||
dtype=torch.int32,
|
||||
device=key.device)
|
||||
|
||||
out = flash_attn_varlen_func(
|
||||
query.flatten(0, 1),
|
||||
key.flatten(0, 1),
|
||||
value.flatten(0, 1),
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=kv_len,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
out = out.reshape(bsz, q_len, -1)
|
||||
elif self.attn_backend == _Backend.XFORMERS:
|
||||
if self.attn_backend == _Backend.XFORMERS:
|
||||
from xformers import ops as xops
|
||||
|
||||
out = xops.memory_efficient_attention_forward(query,
|
||||
|
||||
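The varlen flash-attention call packs the whole batch into one long token dimension, so cu_seqlens_q / cu_seqlens_k are cumulative start offsets rather than per-sample lengths. A small sketch of what the torch.arange construction above yields (illustrative sizes; assumes query was already viewed to (bsz, q_len, num_heads, head_size) just as key/value are above):

import torch

bsz, q_len, kv_len = 2, 3, 5
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 3, 6], dtype=torch.int32)   -> rows 0:3 and 3:6 are the two query sequences
print(cu_seqlens_k)  # tensor([0, 5, 10], dtype=torch.int32)  -> rows 0:5 and 5:10 are the two KV sequences

# query.flatten(0, 1) then has shape (bsz * q_len, num_heads, head_size), and
# these offsets tell the kernel where each sequence starts inside that packed tensor.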