[BugFix] Fix Siglip2Attention on XPU (#28448)

Signed-off-by: Lin, Fanli <fanli.lin@intel.com>
Author: Fanli Lin, 2025-11-12 02:18:02 +08:00, committed by GitHub
parent 6c3c0f8235
commit d5edcb8678

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 
 from .vision import get_vit_attn_backend
@@ -188,7 +189,7 @@ def apply_rotary_pos_emb(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
-    if is_flash_attn_backend:
+    if is_flash_attn_backend and not current_platform.is_xpu():
         from flash_attn.layers.rotary import apply_rotary_emb
 
         apply_rotary_emb_func = apply_rotary_emb
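
For context, when the flash-attn rotary kernel is skipped (as it now is on XPU), the rotation falls back to plain PyTorch ops that run on any device. Below is a minimal sketch of that non-interleaved (GPT-NeoX-style) pattern, assuming the half-width cos/sin produced above; the helper names are illustrative, not vLLM's actual fallback:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Non-interleaved rotation: split the last dim in two, map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    # x: (seq, num_heads, head_dim); cos/sin: (seq, head_dim // 2).
    # Duplicating cos/sin restores full head_dim; everything here is plain
    # torch, so it works on CUDA, XPU, or CPU without a fused kernel.
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
    return (x * cos) + (rotate_half(x) * sin)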
@@ -306,7 +307,13 @@ class Siglip2Attention(nn.Module):
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
             attn_output = self.flash_attn_varlen_func(
-                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
+                queries,
+                keys,
+                values,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
             ).reshape(seq_length, -1)
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
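
As a usage note, the varlen call now passes its sequence-length arguments by keyword, which keeps the call unambiguous if a platform's varlen implementation orders its positional parameters differently. A hedged sketch of how such a call is typically assembled; the wrapper and tensor names here are illustrative, not part of the patch:

import torch


def run_varlen_attention(flash_attn_varlen_func, q, k, v, seq_lens: list[int]):
    # q/k/v: (total_tokens, num_heads, head_dim), packed across all images.
    # cu_seqlens holds cumulative offsets [0, len0, len0+len1, ...] as int32.
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=q.device)
    cu_seqlens[1:] = torch.tensor(seq_lens, device=q.device).cumsum(0)
    max_seqlen = max(seq_lens)
    # Keyword arguments make the intent explicit for any flash-attn-style
    # backend that accepts cu_seqlens_q/k and max_seqlen_q/k.
    return flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
    )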