[BugFix] Fix Siglip2Attention on XPU (#28448)
Signed-off-by: Lin, Fanli <fanli.lin@intel.com>
commit d5edcb8678 (parent 6c3c0f8235)
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 
 from .vision import get_vit_attn_backend
 
@@ -188,7 +189,7 @@ def apply_rotary_pos_emb(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
-    if is_flash_attn_backend:
+    if is_flash_attn_backend and not current_platform.is_xpu():
         from flash_attn.layers.rotary import apply_rotary_emb
 
         apply_rotary_emb_func = apply_rotary_emb
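With this gate, XPU skips the flash_attn.layers.rotary import (the fused rotary kernel there is presumably unavailable on XPU) and falls through to the existing non-flash rotary path. For illustration only, a minimal torch-only sketch of the standard rotate-half rotary formulation such a fallback typically implements; the function name and shapes below are assumptions, not the actual siglip2.py code:

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_torch(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Plain-tensor rotary: x * cos + rotate_half(x) * sin, with cos/sin
    # broadcast over the head dimension. Assumed shapes (illustrative only):
    # q, k: (seq, num_heads, head_dim); cos, sin: (seq, head_dim), where each
    # frequency is duplicated across the two halves (HF-style convention).
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin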
@@ -306,7 +307,13 @@ class Siglip2Attention(nn.Module):
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
             attn_output = self.flash_attn_varlen_func(
-                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
+                queries,
+                keys,
+                values,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
             ).reshape(seq_length, -1)
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
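Passing the cu_seqlens/max_seqlen arguments by keyword pins each value to its parameter name, which is presumably safer given that self.flash_attn_varlen_func can be bound to different flash-attn-style varlen implementations per platform (including the XPU path). For reference, a hedged pure-torch sketch of what the varlen call computes on packed sequences, in the same entry-by-entry spirit as the TORCH_SDPA branch at the end of the hunk; shapes and sizes are illustrative assumptions:

import torch
import torch.nn.functional as F

# Packed ("varlen") layout: all sequences concatenated along dim 0.
num_heads, head_dim = 8, 64
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)  # two seqs: len 3 and 4
total_tokens = int(cu_seqlens[-1])
queries = torch.randn(total_tokens, num_heads, head_dim)
keys = torch.randn_like(queries)
values = torch.randn_like(queries)
# This is what would be passed as max_seqlen_q / max_seqlen_k to the fused call.
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

# Reference computation: run attention separately per sequence, then re-pack,
# mirroring what the fused varlen kernel does in a single call.
outputs = []
for i in range(cu_seqlens.numel() - 1):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    q = queries[start:end].transpose(0, 1)  # (num_heads, seq, head_dim)
    k = keys[start:end].transpose(0, 1)
    v = values[start:end].transpose(0, 1)
    o = F.scaled_dot_product_attention(q, k, v)
    outputs.append(o.transpose(0, 1))        # back to (seq, num_heads, head_dim)
attn_output = torch.cat(outputs).reshape(total_tokens, -1)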