[Model][QwenVL] Optimize Qwen2_5_VisionAttention q,k preparation (#28769)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit 5a87076d6e
parent ac1daf3233
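For orientation: the commit folds the former split_qkv helper plus per-tensor transposes into a single einops.rearrange that produces a [b, s, 3, head, head_dim] tensor. The following standalone sketch (dummy shapes and names, not code from the diff) checks that the old and new q/k/v preparation paths yield identical tensors:

    import einops
    import torch

    seq_len, batch_size, num_heads, head_dim = 16, 1, 4, 8
    x = torch.randn(seq_len, batch_size, 3 * num_heads * head_dim)

    # Old path: chunk into q/k/v, view each to [s, b, head, head_dim],
    # then move the batch axis to the front tensor by tensor.
    q_old, k_old, v_old = x.chunk(3, dim=2)
    shape = (seq_len, batch_size, num_heads, head_dim)
    q_old, k_old, v_old = (
        einops.rearrange(t.view(*shape), "s b ... -> b s ...")
        for t in (q_old, k_old, v_old)
    )

    # New path: one rearrange to [b, s, 3, head, head_dim], then unbind.
    qkv = einops.rearrange(
        x,
        "s b (three head head_dim) -> b s three head head_dim",
        three=3,
        head=num_heads,
    )
    q_new, k_new, v_new = qkv.unbind(dim=2)

    for old, new in ((q_old, q_new), (k_old, k_new), (v_old, v_new)):
        assert torch.equal(old, new)

The new path builds one intermediate layout instead of three separate view/permute round trips, which is the q, k preparation the title refers to.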
@@ -39,8 +39,8 @@ from vllm.model_executor.models.interfaces import (
 )
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
-from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
 from vllm.model_executor.models.qwen2_vl import (
+    Qwen2VisionAttention,
     Qwen2VLDummyInputsBuilder,
     Qwen2VLMultiModalProcessor,
     Qwen2VLProcessingInfo,
@@ -328,7 +328,7 @@ class DotsVisionAttention(nn.Module):
         # [S, C] -> [S, B=1, C]
         x = hidden_states.unsqueeze(1)
         x, _ = self.qkv(x)
-        q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
+        q, k, v = Qwen2VisionAttention.split_qkv(self, x)
         bs = q.shape[1]
         # [S,B,H,D] -> [B,S,H,D]
         q = q.permute(1, 0, 2, 3).contiguous()
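The replacement line above calls a method defined on another class with an explicit self; a minimal sketch of that pattern (class and attribute names are illustrative, not vLLM's), where the receiving object only needs the attributes the helper reads:

    import torch

    class Splitter:
        def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
            seq_len, bs, _ = qkv.shape
            q, k, v = qkv.chunk(3, dim=2)
            shape = (seq_len, bs, self.num_heads, self.head_dim)
            return tuple(t.view(*shape) for t in (q, k, v))

    class OtherAttention:
        def __init__(self, num_heads: int, head_dim: int) -> None:
            # the only attributes Splitter.split_qkv touches
            self.num_heads = num_heads
            self.head_dim = head_dim

    attn = OtherAttention(num_heads=4, head_dim=8)
    x = torch.randn(16, 1, 3 * 4 * 8)
    q, k, v = Splitter.split_qkv(attn, x)  # explicit-self call, as in the diff
    assert q.shape == (16, 1, 4, 8)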
@@ -359,23 +359,6 @@ class Qwen2_5_VisionAttention(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }
 
-    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        # [s, b, 3 * head * head_dim]
-        seq_len, bs, _ = qkv.shape
-
-        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
-        q, k, v = qkv.chunk(3, dim=2)
-
-        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
-        new_shape = (
-            seq_len,
-            bs,
-            self.num_attention_heads_per_partition,
-            self.hidden_size_per_attention_head,
-        )
-        q, k, v = (x.view(*new_shape) for x in (q, k, v))
-        return q, k, v
-
     def forward(
         self,
         x: torch.Tensor,
@@ -386,17 +369,32 @@ class Qwen2_5_VisionAttention(nn.Module):
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
+        seq_len, batch_size, _ = x.shape
 
-        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
-        q, k, v = self.split_qkv(x)
-        batch_size = q.shape[1]
+        qkv = einops.rearrange(
+            x,
+            "s b (three head head_dim) -> b s three head head_dim",
+            three=3,
+            head=self.num_attention_heads_per_partition,
+        )
 
-        q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
         if rotary_pos_emb is not None:
-            # [2 * b, s, heads, head_dim]
-            qk_concat = torch.cat([q, k], dim=0)
-            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
-            q, k = torch.chunk(qk_rotated, 2, dim=0)
+            qk, v = qkv[:, :, :2], qkv[:, :, 2]
+            qk_reshaped = einops.rearrange(
+                qk, "b s two head head_dim -> (two b) s head head_dim", two=2
+            )
+            qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb)
+            qk_rotated = qk_rotated.view(
+                2,
+                batch_size,
+                seq_len,
+                self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head,
+            )
+            q, k = qk_rotated.unbind(dim=0)
+        else:
+            q, k, v = qkv.unbind(dim=2)
 
         if self.is_flash_attn_backend:
             context_layer = vit_flash_attn_wrapper(
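In the rotary branch of the rewritten forward, q and k are folded along the batch axis so apply_rotary_pos_emb_vision runs once rather than twice. A standalone sketch with a stand-in rotary function (fake_rotary is a placeholder, not vLLM's implementation) showing the stack, rotate, and unstack round trip:

    import einops
    import torch

    def fake_rotary(t: torch.Tensor) -> torch.Tensor:
        # placeholder for apply_rotary_pos_emb_vision; any elementwise op works
        return t * 2.0

    b, s, h, d = 1, 16, 4, 8
    qk = torch.randn(b, s, 2, h, d)  # q and k packed along dim=2

    # One call over a (2 * b, s, h, d) tensor ...
    packed = einops.rearrange(qk, "b s two h d -> (two b) s h d", two=2)
    rotated = fake_rotary(packed).reshape(2, b, s, h, d)
    q_once, k_once = rotated.unbind(dim=0)

    # ... matches rotating q and k separately.
    assert torch.equal(q_once, fake_rotary(qk[:, :, 0]))
    assert torch.equal(k_once, fake_rotary(qk[:, :, 1]))

Collapsing the pair into one tensor keeps the rotary application to a single call; the reshape and unbind afterwards restore separate q and k.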