[Multi Modal][Performance] Fused Q,K's apply_rope in more models (#25005)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
commit 035fd2bd2c
parent 1cd885bd54
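The change is the same in all three vision attention modules touched below: instead of applying the vision rotary embedding to Q and K in two separate calls, the two tensors are concatenated along the batch dimension, rotated once, and split back. The following is a minimal, self-contained sketch of why that is safe; toy_rotate and the shapes are illustrative stand-ins and are not vLLM's apply_rotary_pos_emb_vision.

import torch

def toy_rotate(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Simplified rotary embedding, standing in for apply_rotary_pos_emb_vision.
    # x: [b, s, heads, head_dim], freqs: [s, head_dim // 2].
    cos = freqs.cos()[None, :, None, :]   # [1, s, 1, head_dim // 2]
    sin = freqs.sin()[None, :, None, :]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

b, s, h, d = 2, 16, 4, 8                 # illustrative sizes only
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)
freqs = torch.randn(s, d // 2)

# Before: two rope invocations, one per tensor.
q_ref, k_ref = toy_rotate(q, freqs), toy_rotate(k, freqs)

# After: concatenate on the batch dim, rotate once, split back.
qk_rotated = toy_rotate(torch.cat([q, k], dim=0), freqs)
q_fused, k_fused = torch.chunk(qk_rotated, 2, dim=0)

assert torch.allclose(q_ref, q_fused) and torch.allclose(k_ref, k_fused)

Because the rotation acts on each batch element independently, stacking Q and K on dim 0 gives the same result while halving the number of rope invocations. The hunks below apply exactly this cat/rotate/chunk pattern inside each model's forward pass, with vLLM's apply_rotary_pos_emb_vision doing the rotation.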
@@ -234,8 +234,9 @@ class Ernie4_5_VisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            qk_concat = torch.cat([q, k], dim=0)
+            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -261,8 +262,8 @@ class Ernie4_5_VisionAttention(nn.Module):
                                         causal=False)
 
             context_layer = rearrange(output,
-                                      "(b s) ... -> b s ...",
-                                      b=batch_size)
+                                      "(b s) h d -> s b (h d)",
+                                      b=batch_size).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
@@ -281,6 +282,8 @@ class Ernie4_5_VisionAttention(nn.Module):
                 output_i = rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
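Alongside the fused rope, the flash-attention branch above now rearranges its output straight from the packed (b s) layout to s b (h d), instead of unpacking to b s ... and relying on a later shared rearrange, and the TORCH_SDPA branch gains its own rearrange. A small equivalence sketch with made-up shapes, using einops as the model code does; the explicit "h d" stands in for the old pattern's trailing "...":

import torch
from einops import rearrange

b, s, h, d = 2, 16, 4, 8
output = torch.randn(b * s, h, d)  # varlen flash-attn output: [(b s), h, d]

# Old path: unpack the batch first, then flatten heads in a second rearrange.
two_step = rearrange(output, "(b s) h d -> b s h d", b=b)
two_step = rearrange(two_step, "b s h d -> s b (h d)").contiguous()

# New path: go straight from [(b s), h, d] to [s, b, (h d)] in one call.
one_step = rearrange(output, "(b s) h d -> s b (h d)", b=b).contiguous()

assert torch.allclose(two_step, one_step)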
@@ -315,8 +315,10 @@ class Glm4vVisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            # [2 * b, s, heads, head_dim]
+            qk_concat = torch.cat([q, k], dim=0)
+            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -341,8 +343,8 @@ class Glm4vVisionAttention(nn.Module):
             )
 
             context_layer = rearrange(output,
-                                      "(b s) ... -> b s ...",
-                                      b=batch_size)
+                                      "(b s) h d -> s b (h d)",
+                                      b=batch_size).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
@@ -361,6 +363,8 @@ class Glm4vVisionAttention(nn.Module):
                 output_i = rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -371,7 +375,6 @@ class Glm4vVisionAttention(nn.Module):
 
             context_layer = xops.memory_efficient_attention_forward(
                 q, k, v, attn_bias=attn_bias, p=0, scale=None)
 
-        context_layer = rearrange(context_layer,
-                                  "b s h d -> s b (h d)").contiguous()
-
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()
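With the shared tail rearrange folded into the branches, the TORCH_SDPA path now converts its concatenated [b, s, heads, head_dim] output to [s, b, heads * head_dim] itself. A rough, self-contained sketch of that branch's shape flow follows; the per-image boundaries in cu_seqlens are invented for illustration and simplified relative to the real code:

import torch
import torch.nn.functional as F
from einops import rearrange

b, s, h, d = 1, 16, 4, 8
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)
v = torch.randn(b, s, h, d)
cu_seqlens = [0, 6, 16]  # hypothetical image boundaries within the sequence

outputs = []
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    # Slice one image's tokens and move heads ahead of the sequence dim,
    # since scaled_dot_product_attention expects [b, heads, s, head_dim].
    q_i, k_i, v_i = (x[:, start:end].permute(0, 2, 1, 3) for x in (q, k, v))
    out_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
    outputs.append(rearrange(out_i, "b h s d -> b s h d"))

context_layer = torch.cat(outputs, dim=1)                       # [b, s, h, d]
context_layer = rearrange(context_layer,
                          "b s h d -> s b (h d)").contiguous()  # [s, b, h*d]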
@@ -377,8 +377,10 @@ class Qwen2VisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            # [2 * b, s, heads, head_dim]
+            qk_concat = torch.cat([q, k], dim=0)
+            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
             if self.attn_backend == _Backend.ROCM_AITER_FA:
@@ -402,8 +404,8 @@ class Qwen2VisionAttention(nn.Module):
                                         causal=False)
 
             context_layer = rearrange(output,
-                                      "(b s) ... -> b s ...",
-                                      b=batch_size)
+                                      "(b s) h d -> s b (h d)",
+                                      b=batch_size).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
@@ -422,6 +424,8 @@ class Qwen2VisionAttention(nn.Module):
                 output_i = rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
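To sanity-check the performance effect locally, one option is to time both rope strategies with torch.utils.benchmark. This sketch reuses the toy rotation from the first example rather than vLLM's kernel, so the numbers only indicate the trend, not the real speedup:

import torch
import torch.utils.benchmark as benchmark

def toy_rotate(x, freqs):
    # Same simplified stand-in rotary as in the first sketch above.
    cos = freqs.cos()[None, :, None, :]
    sin = freqs.sin()[None, :, None, :]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
b, s, h, d = 4, 1024, 16, 64                 # illustrative sizes only
q = torch.randn(b, s, h, d, device=device)
k = torch.randn(b, s, h, d, device=device)
freqs = torch.randn(s, d // 2, device=device)

common = {"toy_rotate": toy_rotate, "q": q, "k": k, "freqs": freqs,
          "torch": torch}
separate = benchmark.Timer(
    stmt="toy_rotate(q, freqs); toy_rotate(k, freqs)", globals=common)
fused = benchmark.Timer(
    stmt="torch.chunk(toy_rotate(torch.cat([q, k], dim=0), freqs), 2, dim=0)",
    globals=common)

print(separate.blocked_autorange())
print(fused.blocked_autorange())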