diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 405af8f8be42..f46caaa095c6 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -39,8 +39,8 @@ from vllm.model_executor.models.interfaces import (
 )
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
-from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
 from vllm.model_executor.models.qwen2_vl import (
+    Qwen2VisionAttention,
     Qwen2VLDummyInputsBuilder,
     Qwen2VLMultiModalProcessor,
     Qwen2VLProcessingInfo,
@@ -328,7 +328,7 @@ class DotsVisionAttention(nn.Module):
         # [S, C] -> [S, B=1, C]
         x = hidden_states.unsqueeze(1)
         x, _ = self.qkv(x)
-        q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
+        q, k, v = Qwen2VisionAttention.split_qkv(self, x)
         bs = q.shape[1]
         # [S,B,H,D] -> [B,S,H,D]
         q = q.permute(1, 0, 2, 3).contiguous()
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 7617929e93ac..897dd7ef29f1 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -359,23 +359,6 @@ class Qwen2_5_VisionAttention(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }
 
-    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
-        # [s, b, 3 * head * head_dim]
-        seq_len, bs, _ = qkv.shape
-
-        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
-        q, k, v = qkv.chunk(3, dim=2)
-
-        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
-        new_shape = (
-            seq_len,
-            bs,
-            self.num_attention_heads_per_partition,
-            self.hidden_size_per_attention_head,
-        )
-        q, k, v = (x.view(*new_shape) for x in (q, k, v))
-        return q, k, v
-
     def forward(
         self,
         x: torch.Tensor,
@@ -386,17 +369,32 @@
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
+        seq_len, batch_size, _ = x.shape
 
-        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
-        q, k, v = self.split_qkv(x)
-        batch_size = q.shape[1]
+        qkv = einops.rearrange(
+            x,
+            "s b (three head head_dim) -> b s three head head_dim",
+            three=3,
+            head=self.num_attention_heads_per_partition,
+        )
 
-        q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
         if rotary_pos_emb is not None:
-            # [2 * b, s, heads, head_dim]
-            qk_concat = torch.cat([q, k], dim=0)
-            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
-            q, k = torch.chunk(qk_rotated, 2, dim=0)
+            qk, v = qkv[:, :, :2], qkv[:, :, 2]
+
+            qk_reshaped = einops.rearrange(
+                qk, "b s two head head_dim -> (two b) s head head_dim", two=2
+            )
+            qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb)
+            qk_rotated = qk_rotated.view(
+                2,
+                batch_size,
+                seq_len,
+                self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head,
+            )
+            q, k = qk_rotated.unbind(dim=0)
+        else:
+            q, k, v = qkv.unbind(dim=2)
 
         if self.is_flash_attn_backend:
             context_layer = vit_flash_attn_wrapper(
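
Reference only, not part of the patch: a minimal standalone sketch checking that the new einops-based rearrange in the forward path produces the same q/k/v tensors as the removed split_qkv() path. The sizes (seq_len, bs, num_heads, head_dim) are illustrative assumptions, not values taken from vLLM.

import torch
import einops

seq_len, bs, num_heads, head_dim = 16, 1, 4, 8
x = torch.randn(seq_len, bs, 3 * num_heads * head_dim)

# Old path: chunk the joint qkv projection on the last dim, view per head,
# then move the batch dim first ([s, b, h, d] -> [b, s, h, d]).
q0, k0, v0 = (
    t.view(seq_len, bs, num_heads, head_dim).permute(1, 0, 2, 3)
    for t in x.chunk(3, dim=2)
)

# New path: a single rearrange to [b, s, 3, head, head_dim], then unbind.
qkv = einops.rearrange(
    x,
    "s b (three head head_dim) -> b s three head head_dim",
    three=3,
    head=num_heads,
)
q1, k1, v1 = qkv.unbind(dim=2)

assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)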