diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 3396c67f42b7b..0d4aced93ca1c 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -234,8 +234,9 @@ class Ernie4_5_VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: # from vllm_flash_attn.flash_attn_interface import ( @@ -261,8 +262,8 @@ class Ernie4_5_VisionAttention(nn.Module): causal=False) context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] @@ -281,6 +282,8 @@ class Ernie4_5_VisionAttention(nn.Module): output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -291,8 +294,8 @@ class Ernie4_5_VisionAttention(nn.Module): context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index cbf327ce02b6b..308b0cb602bc9 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -315,8 +315,10 @@ class Glm4vVisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( @@ -341,8 +343,8 @@ class Glm4vVisionAttention(nn.Module): ) context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] @@ -361,6 +363,8 @@ class Glm4vVisionAttention(nn.Module): output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -371,9 +375,8 @@ class Glm4vVisionAttention(nn.Module): context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) - - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7f361678ba72e..dd4e7731e0b08 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -377,8 +377,10 @@ class Qwen2VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: if self.attn_backend == _Backend.ROCM_AITER_FA: @@ -402,8 +404,8 @@ class Qwen2VisionAttention(nn.Module): causal=False) context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] @@ -422,6 +424,8 @@ class Qwen2VisionAttention(nn.Module): output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -432,8 +436,8 @@ class Qwen2VisionAttention(nn.Module): context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output