Remove xformers requirement for Mistral-format Pixtral and Mistral3 (#21154)

Signed-off-by: Wenchen Lo <charles761013@gmail.com>
This commit is contained in:
Wenchen Lo 2025-07-26 16:20:29 -07:00 committed by GitHub
parent de509ae8eb
commit 6c66f28fa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -671,7 +671,19 @@ class Attention(nn.Module):
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
if USE_XFORMERS_OPS:
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
else:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(q,
k,
v,
attn_mask=mask)
out = out.transpose(1, 2)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)
@ -814,8 +826,11 @@ class VisionTransformer(nn.Module):
mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
else:
raise ImportError("Xformers is required for Pixtral inference "
"with the Mistral format")
from transformers.models.pixtral.modeling_pixtral import (
generate_block_attention_mask)
mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
# squeeze dim 0 and split into separate tensors for each image