[Multi Modal][Performance] Fused Q,K's apply_rope into one (#24511)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Wenlong Wang 2025-09-14 01:10:21 -07:00 committed by GitHub
parent 3e903b6cb4
commit cc3173ae98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -363,8 +363,10 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
@ -388,8 +390,8 @@ class Qwen2_5_VisionAttention(nn.Module):
causal=False)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
"(b s) h d -> s b (h d)",
b=batch_size).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
@ -408,6 +410,8 @@ class Qwen2_5_VisionAttention(nn.Module):
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@ -418,8 +422,8 @@ class Qwen2_5_VisionAttention(nn.Module):
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output