From d5edcb86781ea56f1eb0c9c5d7482a7cae00ec17 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Wed, 12 Nov 2025 02:18:02 +0800
Subject: [PATCH] [BugFix] Fix Siglip2Attention on XPU (#28448)

Signed-off-by: Lin, Fanli
---
 vllm/model_executor/models/siglip2navit.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index c20bcd975ca30..29dd164ad37fd 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 
 from .vision import get_vit_attn_backend
 
@@ -188,7 +189,7 @@ def apply_rotary_pos_emb(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
-    if is_flash_attn_backend:
+    if is_flash_attn_backend and not current_platform.is_xpu():
         from flash_attn.layers.rotary import apply_rotary_emb
 
         apply_rotary_emb_func = apply_rotary_emb
@@ -306,7 +307,13 @@ class Siglip2Attention(nn.Module):
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
             attn_output = self.flash_attn_varlen_func(
-                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
+                queries,
+                keys,
+                values,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
             ).reshape(seq_length, -1)
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
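
Below is a minimal, self-contained Python sketch of the dispatch the first two hunks introduce; the helper names `apply_rotary_emb_torch` and `pick_rotary_fn` are illustrative only, not vLLM API. The point is that the flash-attn rotary kernel is CUDA-only, so on XPU the code must take a pure-PyTorch path even when the flash-attn backend is selected.

import torch


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    # Pure-PyTorch rotary embedding: rotate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def pick_rotary_fn(is_flash_attn_backend: bool, is_xpu: bool):
    # flash_attn.layers.rotary.apply_rotary_emb is backed by a CUDA kernel,
    # so it is skipped on XPU; this mirrors the
    # `is_flash_attn_backend and not current_platform.is_xpu()` guard above.
    if is_flash_attn_backend and not is_xpu:
        from flash_attn.layers.rotary import apply_rotary_emb

        return apply_rotary_emb
    return apply_rotary_emb_torch

The third hunk passes the varlen arguments by keyword rather than position, which keeps the call robust when `self.flash_attn_varlen_func` resolves to a backend-specific implementation (such as the XPU one) whose positional parameter order may not match upstream flash-attn.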