From 6c66f28fa5dc88ce6f7ab30dfa733f9ddb927d3c Mon Sep 17 00:00:00 2001
From: Wenchen Lo
Date: Sat, 26 Jul 2025 16:20:29 -0700
Subject: [PATCH] Remove xformers requirement for Mistral-format Pixtral and
 Mistral3 (#21154)

Signed-off-by: Wenchen Lo
---
 vllm/model_executor/models/pixtral.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 325a264a2f4c..41eaf372785e 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -671,7 +671,19 @@ class Attention(nn.Module):
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)

         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
+
+        if USE_XFORMERS_OPS:
+            out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
+        else:
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            out = nn.functional.scaled_dot_product_attention(q,
+                                                             k,
+                                                             v,
+                                                             attn_mask=mask)
+            out = out.transpose(1, 2)
+
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)

         return self.wo(out)
@@ -814,8 +826,11 @@ class VisionTransformer(nn.Module):
             mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                 [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
         else:
-            raise ImportError("Xformers is required for Pixtral inference "
-                              "with the Mistral format")
+            from transformers.models.pixtral.modeling_pixtral import (
+                generate_block_attention_mask)
+            mask = generate_block_attention_mask(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+                patch_embeds)
         out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

         # squeeze dim 0 and split into separate tensors for each image
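
Note (not part of the patch): the snippet below is a minimal, standalone sketch of the fallback path this diff introduces, for reviewers without xformers installed. The helper names block_attention_mask and sdpa_fallback are illustrative, not vLLM or transformers APIs. It builds an additive block-diagonal mask in the spirit of generate_block_attention_mask, applies torch's scaled_dot_product_attention with the same transposes as the patched Attention.forward (xformers takes (batch, seq, heads, head_dim), SDPA takes (batch, heads, seq, head_dim)), and compares the result against a naive softmax(QK^T / sqrt(d) + mask) V reference.

# Illustrative only: assumed shapes and helper names, not vLLM internals.
import math

import torch
import torch.nn.functional as F


def block_attention_mask(seq_lens, dtype):
    # Additive mask: 0 inside each image's block, a large negative value
    # elsewhere, shaped (1, 1, S, S) so it broadcasts over batch and heads.
    total = sum(seq_lens)
    mask = torch.full((total, total), torch.finfo(dtype).min, dtype=dtype)
    start = 0
    for n in seq_lens:
        mask[start:start + n, start:start + n] = 0
        start += n
    return mask[None, None]


def sdpa_fallback(q, k, v, mask):
    # q, k, v arrive as (batch, patches, heads, head_dim) -- the layout
    # xops.memory_efficient_attention expects; SDPA wants
    # (batch, heads, patches, head_dim), hence the transposes on both sides.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2)


if __name__ == "__main__":
    batch, heads, head_dim = 1, 4, 16
    seq_lens = [9, 16]  # e.g. patch counts of two images packed into one sequence
    patches = sum(seq_lens)
    q, k, v = (torch.randn(batch, patches, heads, head_dim) for _ in range(3))
    mask = block_attention_mask(seq_lens, q.dtype)

    out = sdpa_fallback(q, k, v, mask)

    # Naive reference: softmax(QK^T / sqrt(d) + mask) V, computed per head.
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))
    scores = qh @ kh.transpose(-2, -1) / math.sqrt(head_dim) + mask
    ref = (scores.softmax(dim=-1) @ vh).transpose(1, 2)

    assert torch.allclose(out, ref, atol=1e-5)
    print("SDPA fallback matches naive attention:", tuple(out.shape))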