mirror of https://git.datalinker.icu/vllm-project/vllm.git
[BugFix] Fix Siglip2Attention on XPU (#28448)

Signed-off-by: Lin, Fanli <fanli.lin@intel.com>

parent 6c3c0f8235
commit d5edcb8678
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform

 from .vision import get_vit_attn_backend

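The new import exists only to enable the XPU check in the next hunk. A minimal sketch of how that platform query is typically used; the helper name select_rotary_impl is hypothetical, only current_platform.is_xpu() appears in this commit:

    from vllm.platforms import current_platform


    def select_rotary_impl(is_flash_attn_backend: bool) -> str:
        # flash_attn ships CUDA kernels, so its fused rotary op is skipped on
        # XPU even when the flash-attn backend was chosen for attention.
        if is_flash_attn_backend and not current_platform.is_xpu():
            return "flash_attn_fused_rotary"
        return "torch_native_rotary"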
@@ -188,7 +189,7 @@ def apply_rotary_pos_emb(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
-    if is_flash_attn_backend:
+    if is_flash_attn_backend and not current_platform.is_xpu():
         from flash_attn.layers.rotary import apply_rotary_emb

         apply_rotary_emb_func = apply_rotary_emb
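When the gated condition is false (for example on XPU), apply_rotary_pos_emb must apply the rotary embedding with plain PyTorch ops instead of flash_attn's fused kernel. A minimal, self-contained sketch of such a fallback, not the vLLM implementation itself; the shapes (seq, num_heads, head_dim) for x and (seq, head_dim // 2) for cos/sin are assumptions:

    import torch


    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def apply_rotary_emb_torch(
        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        # cos/sin cover half of the head dimension (they were chunked above),
        # so duplicate them to full width before the elementwise rotation.
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)  # (seq, 1, head_dim)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        return (x * cos) + (rotate_half(x) * sin)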
@@ -306,7 +307,13 @@ class Siglip2Attention(nn.Module):
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         if self.is_flash_attn_backend:
             attn_output = self.flash_attn_varlen_func(
-                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
+                queries,
+                keys,
+                values,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
             ).reshape(seq_length, -1)
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
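The TORCH_SDPA branch introduced by the last context line runs attention per sequence instead of through a packed varlen kernel. A minimal sketch of that pattern, with a hypothetical function name and assumed shapes rather than the code in this file: it slices the packed tensors with cu_seqlens and calls torch SDPA once per sequence.

    import torch
    import torch.nn.functional as F


    def sdpa_varlen(
        queries: torch.Tensor,     # (total_tokens, num_heads, head_dim)
        keys: torch.Tensor,
        values: torch.Tensor,
        cu_seqlens: torch.Tensor,  # (num_seqs + 1,), cumulative sequence lengths
    ) -> torch.Tensor:
        outputs = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            # SDPA expects (batch, heads, seq, dim), so slice one sequence and
            # move the head dimension to the front.
            q = queries[start:end].transpose(0, 1).unsqueeze(0)
            k = keys[start:end].transpose(0, 1).unsqueeze(0)
            v = values[start:end].transpose(0, 1).unsqueeze(0)
            out = F.scaled_dot_product_attention(q, k, v)
            outputs.append(out.squeeze(0).transpose(0, 1))
        # Re-pack to (total_tokens, num_heads * head_dim), matching the
        # .reshape(seq_length, -1) on the flash-attn path above.
        return torch.cat(outputs, dim=0).flatten(1)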