[Model][QwenVL] Optimize Qwen2_5_VisionAttention q,k preparation (#28769)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Lukas Geiger 2025-11-16 17:37:15 +00:00 committed by GitHub
parent ac1daf3233
commit 5a87076d6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 27 deletions

View File

@ -39,8 +39,8 @@ from vllm.model_executor.models.interfaces import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
from vllm.model_executor.models.qwen2_vl import (
Qwen2VisionAttention,
Qwen2VLDummyInputsBuilder,
Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo,
@ -328,7 +328,7 @@ class DotsVisionAttention(nn.Module):
# [S, C] -> [S, B=1, C]
x = hidden_states.unsqueeze(1)
x, _ = self.qkv(x)
q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
q, k, v = Qwen2VisionAttention.split_qkv(self, x)
bs = q.shape[1]
# [S,B,H,D] -> [B,S,H,D]
q = q.permute(1, 0, 2, 3).contiguous()

View File

@ -359,23 +359,6 @@ class Qwen2_5_VisionAttention(nn.Module):
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape = (
seq_len,
bs,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
return q, k, v
def forward(
self,
x: torch.Tensor,
@ -386,17 +369,32 @@ class Qwen2_5_VisionAttention(nn.Module):
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
seq_len, batch_size, _ = x.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
qkv = einops.rearrange(
x,
"s b (three head head_dim) -> b s three head head_dim",
three=3,
head=self.num_attention_heads_per_partition,
)
q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
if rotary_pos_emb is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
q, k = torch.chunk(qk_rotated, 2, dim=0)
qk, v = qkv[:, :, :2], qkv[:, :, 2]
qk_reshaped = einops.rearrange(
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
)
qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb)
qk_rotated = qk_rotated.view(
2,
batch_size,
seq_len,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
q, k = qk_rotated.unbind(dim=0)
else:
q, k, v = qkv.unbind(dim=2)
if self.is_flash_attn_backend:
context_layer = vit_flash_attn_wrapper(