mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Multimodal][XPU]Enable vision attn backend for xpu platform (#27525)
Signed-off-by: Yan Ma <yan.ma@intel.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Yejing Lai <yejing.lai@intel.com> Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
3a5de7d2d6
commit
7e2729b57e
@ -270,21 +270,23 @@ class ipex_ops:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flash_attn_varlen_func(
|
def flash_attn_varlen_func(
|
||||||
out: torch.Tensor,
|
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens_q: torch.Tensor,
|
cu_seqlens_q: torch.Tensor,
|
||||||
seqused_k: torch.Tensor, # we don't support this in ipex kernel
|
|
||||||
max_seqlen_q: int,
|
max_seqlen_q: int,
|
||||||
max_seqlen_k: int,
|
max_seqlen_k: int,
|
||||||
softmax_scale: float,
|
softmax_scale: float | None = None,
|
||||||
causal: bool,
|
causal: bool = False,
|
||||||
block_table: torch.Tensor,
|
out: torch.Tensor | None = None,
|
||||||
alibi_slopes: torch.Tensor | None,
|
block_table: torch.Tensor | None = None,
|
||||||
|
alibi_slopes: torch.Tensor | None = None,
|
||||||
window_size: list[int] | None = None,
|
window_size: list[int] | None = None,
|
||||||
softcap: float | None = 0.0,
|
softcap: float | None = 0.0,
|
||||||
|
seqused_k: torch.Tensor | None = None,
|
||||||
cu_seqlens_k: torch.Tensor | None = None,
|
cu_seqlens_k: torch.Tensor | None = None,
|
||||||
|
# passed in qwen vl
|
||||||
|
dropout_p: float = 0.0,
|
||||||
# The following parameters are not used in ipex kernel currently,
|
# The following parameters are not used in ipex kernel currently,
|
||||||
# we keep API compatible to CUDA's.
|
# we keep API compatible to CUDA's.
|
||||||
scheduler_metadata=None,
|
scheduler_metadata=None,
|
||||||
@ -295,31 +297,63 @@ class ipex_ops:
|
|||||||
num_splits=0,
|
num_splits=0,
|
||||||
s_aux: torch.Tensor | None = None,
|
s_aux: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
|
if out is None:
|
||||||
|
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
||||||
real_window_size: tuple[int, int]
|
real_window_size: tuple[int, int]
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
real_window_size = (-1, -1)
|
real_window_size = (-1, -1)
|
||||||
else:
|
else:
|
||||||
assert len(window_size) == 2
|
assert len(window_size) == 2
|
||||||
real_window_size = (window_size[0], window_size[1])
|
real_window_size = (window_size[0], window_size[1])
|
||||||
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
|
||||||
out,
|
if block_table is None:
|
||||||
q.contiguous(),
|
assert cu_seqlens_k is not None, (
|
||||||
k,
|
"cu_seqlens_k can't be None when calling varlen_attention."
|
||||||
v,
|
)
|
||||||
cu_seqlens_q,
|
if softmax_scale is None:
|
||||||
seqused_k,
|
softmax_scale = q.shape[-1] ** (-0.5)
|
||||||
max_seqlen_q,
|
ipex_ops.varlen_attention(
|
||||||
max_seqlen_k,
|
q.contiguous(),
|
||||||
softmax_scale,
|
k.contiguous(),
|
||||||
causal,
|
v.contiguous(),
|
||||||
block_table,
|
out,
|
||||||
alibi_slopes,
|
cu_seqlens_q,
|
||||||
softcap=softcap,
|
cu_seqlens_k,
|
||||||
window_size_left=real_window_size[0],
|
None,
|
||||||
window_size_right=real_window_size[1],
|
max_seqlen_q,
|
||||||
k_scale=1.0,
|
max_seqlen_k,
|
||||||
v_scale=1.0,
|
0.0,
|
||||||
)
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
real_window_size[0],
|
||||||
|
real_window_size[1],
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||||
|
out,
|
||||||
|
q.contiguous(),
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q,
|
||||||
|
seqused_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
softmax_scale,
|
||||||
|
causal,
|
||||||
|
block_table,
|
||||||
|
alibi_slopes,
|
||||||
|
sink=s_aux,
|
||||||
|
softcap=softcap,
|
||||||
|
window_size_left=real_window_size[0],
|
||||||
|
window_size_right=real_window_size[1],
|
||||||
|
k_scale=1.0,
|
||||||
|
v_scale=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_scheduler_metadata(
|
def get_scheduler_metadata(
|
||||||
|
|||||||
@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
|
|||||||
):
|
):
|
||||||
attn_backend = _Backend.FLASH_ATTN
|
attn_backend = _Backend.FLASH_ATTN
|
||||||
use_upstream_fa = True
|
use_upstream_fa = True
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
assert attn_backend == _Backend.FLASH_ATTN, (
|
||||||
|
"XPU platform only supports FLASH_ATTN as vision attention backend."
|
||||||
|
)
|
||||||
|
use_upstream_fa = False
|
||||||
else:
|
else:
|
||||||
return _Backend.TORCH_SDPA, None
|
return _Backend.TORCH_SDPA, None
|
||||||
|
|
||||||
@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
|
|||||||
if use_upstream_fa:
|
if use_upstream_fa:
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
|
||||||
@ -521,22 +526,18 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# If vllm native fa is selected, we use it directly.
|
# If vllm native fa is selected, we use it directly.
|
||||||
use_upstream_fa = False
|
use_upstream_fa = False
|
||||||
|
|
||||||
if current_platform.is_xpu():
|
self.attn_backend = (
|
||||||
# currently, only torch_sdpa is supported on xpu
|
backend
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
if backend
|
||||||
else:
|
in {
|
||||||
self.attn_backend = (
|
_Backend.TORCH_SDPA,
|
||||||
backend
|
_Backend.XFORMERS,
|
||||||
if backend
|
_Backend.PALLAS,
|
||||||
in {
|
_Backend.ROCM_AITER_FA,
|
||||||
_Backend.TORCH_SDPA,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.XFORMERS,
|
}
|
||||||
_Backend.PALLAS,
|
else _Backend.TORCH_SDPA
|
||||||
_Backend.ROCM_AITER_FA,
|
)
|
||||||
_Backend.FLASH_ATTN,
|
|
||||||
}
|
|
||||||
else _Backend.TORCH_SDPA
|
|
||||||
)
|
|
||||||
|
|
||||||
self.attn_backend, self._flash_attn_varlen_func = (
|
self.attn_backend, self._flash_attn_varlen_func = (
|
||||||
maybe_get_vit_flash_attn_backend(
|
maybe_get_vit_flash_attn_backend(
|
||||||
|
|||||||
@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
if use_upstream_fa:
|
if use_upstream_fa:
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
output = flash_attn_varlen_func(
|
output = flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
|
|||||||
@ -364,6 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
|
|
||||||
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
self.use_upstream_fa = True
|
self.use_upstream_fa = True
|
||||||
|
if current_platform.is_xpu():
|
||||||
|
self.use_upstream_fa = False
|
||||||
self.is_flash_attn_backend = self.attn_backend in {
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.ROCM_AITER_FA,
|
_Backend.ROCM_AITER_FA,
|
||||||
@ -856,10 +858,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
max_seqlen = torch.zeros([], device=cu_seqlens.device)
|
||||||
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
seqlens = torch.zeros(1, device=cu_seqlens.device)
|
||||||
if (
|
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||||
self.attn_backend == _Backend.FLASH_ATTN
|
|
||||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
|
||||||
):
|
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||||
|
|||||||
@ -789,10 +789,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
self, cu_seqlens: torch.Tensor
|
self, cu_seqlens: torch.Tensor
|
||||||
) -> tuple[int | None, list[int] | None]:
|
) -> tuple[int | None, list[int] | None]:
|
||||||
max_seqlen, seqlens = None, None
|
max_seqlen, seqlens = None, None
|
||||||
if (
|
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
|
||||||
self.attn_backend == _Backend.FLASH_ATTN
|
|
||||||
or self.attn_backend == _Backend.ROCM_AITER_FA
|
|
||||||
):
|
|
||||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||||
elif self.attn_backend == _Backend.XFORMERS:
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
@ -115,6 +115,12 @@ class XPUPlatform(Platform):
|
|||||||
device_props = torch.xpu.get_device_properties(device_id)
|
device_props = torch.xpu.get_device_properties(device_id)
|
||||||
return device_props.total_memory
|
return device_props.total_memory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||||
|
from vllm.attention.backends.registry import _Backend
|
||||||
|
|
||||||
|
return _Backend.FLASH_ATTN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def inference_mode(cls):
|
def inference_mode(cls):
|
||||||
return torch.no_grad()
|
return torch.no_grad()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user