Remove xformers requirement for Mistral-format Pixtral and Mistral3 (#21154)
Signed-off-by: Wenchen Lo <charles761013@gmail.com>
parent de509ae8eb
commit 6c66f28fa5
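The change gates the xformers attention path behind a USE_XFORMERS_OPS flag and falls back to PyTorch's native scaled_dot_product_attention when xformers is not installed, so Mistral-format Pixtral and Mistral3 no longer hard-require the library. A minimal sketch of how such a flag is typically derived (an assumption for illustration; the repository's actual detection code may differ):

try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False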
@@ -671,7 +671,19 @@ class Attention(nn.Module):
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
 
         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
+
+        if USE_XFORMERS_OPS:
+            out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
+        else:
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            out = nn.functional.scaled_dot_product_attention(q,
+                                                             k,
+                                                             v,
+                                                             attn_mask=mask)
+            out = out.transpose(1, 2)
+
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
         return self.wo(out)
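For context on the transposes in the fallback branch: xformers' memory_efficient_attention consumes tensors laid out as (batch, seq_len, num_heads, head_dim), while torch.nn.functional.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim), so q, k, and v are permuted before the call and the output is permuted back. A self-contained sketch of that equivalence (shapes are illustrative and the block mask is omitted):

import torch
import torch.nn as nn

batch, patches, n_heads, head_dim = 2, 16, 4, 32
q = torch.randn(batch, patches, n_heads, head_dim)
k = torch.randn(batch, patches, n_heads, head_dim)
v = torch.randn(batch, patches, n_heads, head_dim)

# SDPA path, mirroring the fallback branch above.
out = nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
out = out.transpose(1, 2)

# Reference: explicit softmax attention in the (batch, seq, heads, dim) layout.
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * head_dim ** -0.5
ref = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

assert torch.allclose(out, ref, atol=1e-5)  # same result, different layouts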
@@ -814,8 +826,11 @@ class VisionTransformer(nn.Module):
             mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                 [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
         else:
-            raise ImportError("Xformers is required for Pixtral inference "
-                              "with the Mistral format")
+            from transformers.models.pixtral.modeling_pixtral import (
+                generate_block_attention_mask)
+            mask = generate_block_attention_mask(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+                patch_embeds)
         out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
 
         # squeeze dim 0 and split into separate tensors for each image
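Both branches build the same kind of mask: a block-diagonal structure over the flattened patch sequence, so patches from different images cannot attend to each other. A hypothetical minimal version of that idea (illustrative only; generate_block_attention_mask in transformers has its own signature and output shape):

import torch

def block_diagonal_mask(seq_lens: list[int]) -> torch.Tensor:
    # Additive mask: 0 within each image's block, -inf across images.
    total = sum(seq_lens)
    mask = torch.full((total, total), float("-inf"))
    start = 0
    for n in seq_lens:
        mask[start:start + n, start:start + n] = 0.0
        start += n
    return mask

# Two images contributing 4 and 2 patches to one flattened sequence.
print(block_diagonal_mask([4, 2]))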