mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 05:51:19 +08:00
[Model][Pixtral] Use memory_efficient_attention for PixtralHFVision (#9520)
This commit is contained in:
parent
5b59fe0f08
commit
962d2c6349
@ -13,8 +13,7 @@ from transformers import PixtralVisionConfig, PretrainedConfig
|
|||||||
from transformers.models.pixtral.image_processing_pixtral import (
|
from transformers.models.pixtral.image_processing_pixtral import (
|
||||||
_num_image_tokens)
|
_num_image_tokens)
|
||||||
from transformers.models.pixtral.modeling_pixtral import (
|
from transformers.models.pixtral.modeling_pixtral import (
|
||||||
PixtralRotaryEmbedding, apply_rotary_pos_emb,
|
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
|
||||||
generate_block_attention_mask, position_ids_in_meshgrid)
|
|
||||||
from xformers.ops.fmha import memory_efficient_attention
|
from xformers.ops.fmha import memory_efficient_attention
|
||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||||
|
|
||||||
@ -813,48 +812,30 @@ class PixtralHFAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: BlockDiagonalMask,
|
||||||
position_embeddings: torch.Tensor,
|
position_embeddings: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
batch, patches, _ = hidden_states.size()
|
||||||
|
|
||||||
batch_size, patches, _ = hidden_states.size()
|
q = self.q_proj(hidden_states)
|
||||||
|
k = self.k_proj(hidden_states)
|
||||||
query_states = self.q_proj(hidden_states)
|
v = self.v_proj(hidden_states)
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(batch_size, patches, self.n_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(batch_size, patches, self.n_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(batch_size, patches, self.n_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
|
# Transpose q and k to apply HF's Rotary Position Embedding
|
||||||
|
q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states,
|
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
|
||||||
key_states,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
unsqueeze_dim=0)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(
|
# Transpose q and k back for attention
|
||||||
2, 3)) * self.scale
|
q = q.transpose(1, 2).contiguous()
|
||||||
|
k = k.transpose(1, 2).contiguous()
|
||||||
|
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
|
||||||
|
|
||||||
if attention_mask is not None:
|
out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
|
||||||
attn_weights = attn_weights + attention_mask
|
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
|
||||||
|
|
||||||
# upcast attention to fp32
|
return self.o_proj(out)
|
||||||
attn_weights = nn.functional.softmax(attn_weights,
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.float32).to(
|
|
||||||
query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
attn_output = attn_output.reshape(batch_size, patches, -1)
|
|
||||||
|
|
||||||
return self.o_proj(attn_output)
|
|
||||||
|
|
||||||
|
|
||||||
class PixtralHFTransformerBlock(nn.Module):
|
class PixtralHFTransformerBlock(nn.Module):
|
||||||
@ -869,7 +850,7 @@ class PixtralHFTransformerBlock(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: BlockDiagonalMask,
|
||||||
position_embeddings: torch.Tensor,
|
position_embeddings: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r = self.attention.forward(self.attention_norm(hidden_states),
|
r = self.attention.forward(self.attention_norm(hidden_states),
|
||||||
@ -892,7 +873,7 @@ class PixtralHFTransformer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: BlockDiagonalMask,
|
||||||
position_embeddings: torch.Tensor,
|
position_embeddings: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
@ -953,9 +934,8 @@ class PixtralHFVisionModel(nn.Module):
|
|||||||
|
|
||||||
position_embedding = self.patch_positional_embedding(
|
position_embedding = self.patch_positional_embedding(
|
||||||
patch_embeds, position_ids)
|
patch_embeds, position_ids)
|
||||||
attention_mask = generate_block_attention_mask(
|
attention_mask = BlockDiagonalMask.from_seqlens(
|
||||||
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
|
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
|
||||||
patch_embeds)
|
|
||||||
out = self.transformer(patch_embeds, attention_mask,
|
out = self.transformer(patch_embeds, attention_mask,
|
||||||
position_embedding)
|
position_embedding)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user