From 962d2c63495e930cdd3b59479dce1de48be57ecd Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sun, 20 Oct 2024 01:29:14 -0400
Subject: [PATCH] [Model][Pixtral] Use memory_efficient_attention for
 PixtralHFVision (#9520)

---
 vllm/model_executor/models/pixtral.py | 62 +++++++++------------------
 1 file changed, 21 insertions(+), 41 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index b07ac5baecda9..13c5149a63919 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -13,8 +13,7 @@
 from transformers import PixtralVisionConfig, PretrainedConfig
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
-    PixtralRotaryEmbedding, apply_rotary_pos_emb,
-    generate_block_attention_mask, position_ids_in_meshgrid)
+    PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
 from xformers.ops.fmha import memory_efficient_attention
 from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -813,48 +812,30 @@ class PixtralHFAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: BlockDiagonalMask,
         position_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """Input shape: Batch x Time x Channel"""
+        batch, patches, _ = hidden_states.size()
 
-        batch_size, patches, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(batch_size, patches, self.n_heads,
-                                         self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, patches, self.n_heads,
-                                     self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, patches, self.n_heads,
-                                         self.head_dim).transpose(1, 2)
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
+        # Transpose q and k to apply HF's Rotary Position Embedding
+        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states,
-                                                        key_states,
-                                                        cos,
-                                                        sin,
-                                                        unsqueeze_dim=0)
+        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(
-            2, 3)) * self.scale
+        # Transpose q and k back for attention
+        q = q.transpose(1, 2).contiguous()
+        k = k.transpose(1, 2).contiguous()
+        v = v.reshape(batch, patches, self.n_heads, self.head_dim)
 
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
+        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1,
-                                             dtype=torch.float32).to(
-                                                 query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, patches, -1)
-
-        return self.o_proj(attn_output)
+        return self.o_proj(out)
 
 
 class PixtralHFTransformerBlock(nn.Module):
@@ -869,7 +850,7 @@ class PixtralHFTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: BlockDiagonalMask,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(hidden_states),
@@ -892,7 +873,7 @@ class PixtralHFTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: BlockDiagonalMask,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         for layer in self.layers:
@@ -953,9 +934,8 @@ class PixtralHFVisionModel(nn.Module):
         position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)
 
-        attention_mask = generate_block_attention_mask(
-            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
-            patch_embeds)
+        attention_mask = BlockDiagonalMask.from_seqlens(
+            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
 
         out = self.transformer(patch_embeds, attention_mask,
                                position_embedding)
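
Note (reviewer sketch, not part of the patch): the snippet below illustrates the xformers API this diff switches to, with made-up sequence lengths, head count, and head size. BlockDiagonalMask.from_seqlens encodes the same per-image block structure that generate_block_attention_mask previously materialized as a dense additive mask, and memory_efficient_attention consumes it directly, so the full (patches x patches) fp32 attention matrix is never built.

import torch
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# Two hypothetical images that tile into 6 and 4 patches respectively.
seqlens = [6, 4]
mask = BlockDiagonalMask.from_seqlens(seqlens)

# Materializing the bias (debugging helper) shows the block-diagonal
# structure: patch positions from different images get -inf and therefore
# never attend to each other.
print(mask.materialize((sum(seqlens), sum(seqlens))))

if torch.cuda.is_available():
    # memory_efficient_attention expects (batch, seq, heads, head_dim);
    # with a BlockDiagonalMask the per-image sequences are concatenated
    # along the sequence dimension and the batch dimension stays 1, which
    # matches the single concatenated patch sequence PixtralHFVisionModel
    # hands to its transformer.
    n_heads, head_dim = 2, 8  # illustration values only
    q = torch.randn(1, sum(seqlens), n_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = memory_efficient_attention(q, k, v, attn_bias=mask)
    print(out.shape)  # torch.Size([1, 10, 2, 8])

The final out.reshape(batch, patches, self.n_heads * self.head_dim) in the patched forward plays the same role as the old transpose-and-reshape of attn_output before o_proj.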