[Misc][Bugfix] FA3 support to ViT MHA layer (#12435)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent 324960a95c
commit 2a0309a646
@@ -251,9 +251,28 @@ class MultiHeadAttention(nn.Module):
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
 
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
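The substance of the change is that the ViT MultiHeadAttention path now builds cumulative sequence-length offsets and packs the batch into the token layout expected by flash_attn_varlen_func, the entry point that can dispatch to FlashAttention 3. Below is a minimal standalone sketch, not part of the commit: the sizes bsz=2, q_len=kv_len=4, num_heads=8 and head_dim=64 are made up for illustration, and the actual varlen kernel call is only indicated in a comment. It shows what cu_seqlens_q/cu_seqlens_k contain for a fixed-length patch batch and how flatten(0, 1) produces the packed layout.

# Minimal sketch (illustrative only; shapes are hypothetical)
import torch

bsz, q_len, kv_len = 2, 4, 4          # pretend batch of ViT patch sequences
num_heads, head_dim = 8, 64

query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, kv_len, num_heads, head_dim)
value = torch.randn(bsz, kv_len, num_heads, head_dim)

# Cumulative sequence boundaries: [0, q_len, 2*q_len, ..., bsz*q_len].
# Every image contributes the same number of patches, so a plain arange
# with step=q_len is enough; the varlen kernel uses these offsets to find
# where each sequence starts and ends inside the packed tensor.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)
print(cu_seqlens_q)   # tensor([0, 4, 8], dtype=torch.int32)

# flatten(0, 1) merges the batch and sequence dims into the packed
# (total_tokens, num_heads, head_dim) layout the varlen kernel expects.
packed_q = query.flatten(0, 1)
assert packed_q.shape == (bsz * q_len, num_heads, head_dim)

# With FA3 available, the call in the diff above would then be:
#   out = flash_attn_varlen_func(packed_q,
#                                key.flatten(0, 1),
#                                value.flatten(0, 1),
#                                cu_seqlens_q=cu_seqlens_q,
#                                cu_seqlens_k=cu_seqlens_k,
#                                max_seqlen_q=q_len,
#                                max_seqlen_k=kv_len,
#                                softmax_scale=head_dim ** -0.5)

The trailing reshape in the diff simply restores the (bsz, q_len, hidden_size) view that the rest of the layer expects, since the varlen kernel returns per-token outputs in the packed layout.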