[Multimodal][XPU] Enable vision attn backend for XPU platform (#27525)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Yejing Lai <yejing.lai@intel.com>
Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Yan Ma authored on 2025-11-01 12:45:02 +08:00, committed by GitHub
parent 3a5de7d2d6
commit 7e2729b57e
6 changed files with 88 additions and 51 deletions


@@ -270,21 +270,23 @@ class ipex_ops:
     @staticmethod
     def flash_attn_varlen_func(
-        out: torch.Tensor,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
-        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
         max_seqlen_q: int,
         max_seqlen_k: int,
-        softmax_scale: float,
-        causal: bool,
-        block_table: torch.Tensor,
-        alibi_slopes: torch.Tensor | None,
+        softmax_scale: float | None = None,
+        causal: bool = False,
+        out: torch.Tensor | None = None,
+        block_table: torch.Tensor | None = None,
+        alibi_slopes: torch.Tensor | None = None,
         window_size: list[int] | None = None,
         softcap: float | None = 0.0,
+        seqused_k: torch.Tensor | None = None,
         cu_seqlens_k: torch.Tensor | None = None,
+        # passed in qwen vl
+        dropout_p: float = 0.0,
         # The following parameters are not used in ipex kernel currently,
         # we keep API compatible to CUDA's.
         scheduler_metadata=None,
@@ -295,31 +297,63 @@ class ipex_ops:
         num_splits=0,
         s_aux: torch.Tensor | None = None,
     ):
+        if out is None:
+            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
         real_window_size: tuple[int, int]
         if window_size is None:
             real_window_size = (-1, -1)
         else:
             assert len(window_size) == 2
             real_window_size = (window_size[0], window_size[1])
-        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
-            out,
-            q.contiguous(),
-            k,
-            v,
-            cu_seqlens_q,
-            seqused_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            softmax_scale,
-            causal,
-            block_table,
-            alibi_slopes,
-            softcap=softcap,
-            window_size_left=real_window_size[0],
-            window_size_right=real_window_size[1],
-            k_scale=1.0,
-            v_scale=1.0,
-        )
+        if block_table is None:
+            assert cu_seqlens_k is not None, (
+                "cu_seqlens_k can't be None when calling varlen_attention."
+            )
+            if softmax_scale is None:
+                softmax_scale = q.shape[-1] ** (-0.5)
+            ipex_ops.varlen_attention(
+                q.contiguous(),
+                k.contiguous(),
+                v.contiguous(),
+                out,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                None,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+                real_window_size[0],
+                real_window_size[1],
+                -1,
+            )
+            return out
+        else:
+            return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+                out,
+                q.contiguous(),
+                k,
+                v,
+                cu_seqlens_q,
+                seqused_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                softmax_scale,
+                causal,
+                block_table,
+                alibi_slopes,
+                sink=s_aux,
+                softcap=softcap,
+                window_size_left=real_window_size[0],
+                window_size_right=real_window_size[1],
+                k_scale=1.0,
+                v_scale=1.0,
+            )

     @staticmethod
     def get_scheduler_metadata(
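
With the reordered signature above, a single flash_attn_varlen_func entry point now covers both the paged-KV path (a block_table is supplied, so the call still goes to ipex.llm.modules.PagedAttention) and the vision-encoder path (no block_table, so it falls through to ipex_ops.varlen_attention and returns a freshly allocated out). The sketch below exercises the vision-style call; it is a minimal illustration only: the vllm._ipex_ops import path, the tensor shapes, and the availability of an XPU device with IPEX installed are assumptions, not guaranteed by this diff.

import torch
from vllm._ipex_ops import ipex_ops  # import path assumed, not shown in this diff

# Packed "varlen" layout: three images with 16, 32 and 8 vision tokens each.
seqlens = torch.tensor([16, 32, 8], dtype=torch.int32, device="xpu")
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
total_tokens, num_heads, head_dim = int(seqlens.sum()), 8, 64

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="xpu")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Vision path: block_table is left as None, so the wrapper asserts that
# cu_seqlens_k is given, fills softmax_scale with head_dim**-0.5, and
# dispatches to ipex_ops.varlen_attention.
out = ipex_ops.flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    max_seqlen_q=int(seqlens.max()),
    max_seqlen_k=int(seqlens.max()),
    causal=False,
    cu_seqlens_k=cu_seqlens,  # required whenever block_table is None
)
print(out.shape)  # torch.Size([56, 8, 64])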


@@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
     ):
         attn_backend = _Backend.FLASH_ATTN
         use_upstream_fa = True
+    elif current_platform.is_xpu():
+        assert attn_backend == _Backend.FLASH_ATTN, (
+            "XPU platform only supports FLASH_ATTN as vision attention backend."
+        )
+        use_upstream_fa = False
     else:
         return _Backend.TORCH_SDPA, None
@@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
         if use_upstream_fa:
             from flash_attn import flash_attn_varlen_func
         else:
-            from vllm.vllm_flash_attn import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None
@@ -521,22 +526,18 @@ class MultiHeadAttention(nn.Module):
         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
-        if current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on xpu backend
-            self.attn_backend = _Backend.TORCH_SDPA
-        else:
-            self.attn_backend = (
-                backend
-                if backend
-                in {
-                    _Backend.TORCH_SDPA,
-                    _Backend.XFORMERS,
-                    _Backend.PALLAS,
-                    _Backend.ROCM_AITER_FA,
-                    _Backend.FLASH_ATTN,
-                }
-                else _Backend.TORCH_SDPA
-            )
+        self.attn_backend = (
+            backend
+            if backend
+            in {
+                _Backend.TORCH_SDPA,
+                _Backend.XFORMERS,
+                _Backend.PALLAS,
+                _Backend.ROCM_AITER_FA,
+                _Backend.FLASH_ATTN,
+            }
+            else _Backend.TORCH_SDPA
+        )
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
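
Taken together, the hunks above mean MultiHeadAttention on XPU no longer falls back to _Backend.TORCH_SDPA: the requested backend goes through the same whitelist as on other platforms, and maybe_get_vit_flash_attn_backend then insists on FLASH_ATTN for XPU while resolving flash_attn_varlen_func from vllm.attention.utils.fa_utils instead of upstream flash_attn. The snippet below is a condensed, self-contained restatement of that decision flow, not the real function body; the names pick_vit_backend and platform_is_xpu are made up for illustration.

from enum import Enum, auto


class Backend(Enum):  # stand-in for vllm's _Backend registry
    TORCH_SDPA = auto()
    XFORMERS = auto()
    PALLAS = auto()
    ROCM_AITER_FA = auto()
    FLASH_ATTN = auto()


def pick_vit_backend(requested: Backend, platform_is_xpu: bool) -> tuple[Backend, bool]:
    """Condensed sketch of the post-change vision backend selection."""
    allowed = {
        Backend.TORCH_SDPA,
        Backend.XFORMERS,
        Backend.PALLAS,
        Backend.ROCM_AITER_FA,
        Backend.FLASH_ATTN,
    }
    backend = requested if requested in allowed else Backend.TORCH_SDPA

    use_upstream_fa = False
    if platform_is_xpu:
        # XPU only supports FLASH_ATTN as the vision attention backend, and it
        # uses vllm's fa_utils shim (backed by ipex_ops) rather than flash_attn.
        assert backend == Backend.FLASH_ATTN, "XPU requires FLASH_ATTN for ViT"
    return backend, use_upstream_fa


print(pick_vit_backend(Backend.FLASH_ATTN, platform_is_xpu=True))
# (<Backend.FLASH_ATTN: 5>, False)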


@@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
     if use_upstream_fa:
         from flash_attn import flash_attn_varlen_func
     else:
-        from vllm.vllm_flash_attn import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,


@@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
         if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
             self.use_upstream_fa = True
+        if current_platform.is_xpu():
+            self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN,
             _Backend.ROCM_AITER_FA,
@@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
         seqlens = torch.zeros(1, device=cu_seqlens.device)
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
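
The simplified condition above (and the analogous one in Qwen2VisionTransformer in the next hunk) leaves the cu_seqlens arithmetic untouched. As a quick worked example of what it computes: with boundaries [0, 4, 10, 13] the packed batch holds sequences of length 4, 6 and 3, so max_seqlen is 6.

import torch

cu_seqlens = torch.tensor([0, 4, 10, 13])   # boundaries of 3 packed sequences
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 6, 3])  (XFORMERS path)
max_seqlen = seqlens.max()                  # tensor(6)          (FLASH_ATTN path)
print(seqlens.tolist(), max_seqlen.item())  # [4, 6, 3] 6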


@@ -789,10 +789,7 @@ class Qwen2VisionTransformer(nn.Module):
         self, cu_seqlens: torch.Tensor
     ) -> tuple[int | None, list[int] | None]:
         max_seqlen, seqlens = None, None
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()


@@ -115,6 +115,12 @@ class XPUPlatform(Platform):
         device_props = torch.xpu.get_device_properties(device_id)
         return device_props.total_memory

+    @classmethod
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
+        from vllm.attention.backends.registry import _Backend
+
+        return _Backend.FLASH_ATTN
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
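
This new platform hook is what the vision layers above consult: for any head size and dtype, XPU now reports FLASH_ATTN as its ViT attention backend. A minimal usage sketch, assuming the class lives at vllm.platforms.xpu (the file path is not shown in this diff):

import torch
from vllm.platforms.xpu import XPUPlatform  # module path assumed

backend = XPUPlatform.get_vit_attn_backend(head_size=64, dtype=torch.bfloat16)
print(backend)  # _Backend.FLASH_ATTN, regardless of head_size or dtype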