[XPU] Whisper model support on XPU Platform (#25123)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
commit 3bc18127ff
parent bec060fd99
@@ -391,8 +391,8 @@ class MultiHeadAttention(nn.Module):
             backend = _Backend.FLASH_ATTN
             use_upstream_fa = True

-        if current_platform.is_rocm():
-            # currently, only torch_sdpa is supported on rocm
+        if current_platform.is_rocm() or current_platform.is_xpu():
+            # currently, only torch_sdpa is supported on rocm/xpu
             self.attn_backend = _Backend.TORCH_SDPA
         else:
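The effect of this hunk: on XPU, as already on ROCm, the multi-modal MultiHeadAttention path skips FlashAttention and falls back to the TORCH_SDPA backend. A minimal sketch of what that fallback amounts to (tensor names and shapes are illustrative only, not taken from this diff):

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# The TORCH_SDPA backend ultimately dispatches to this stock PyTorch op,
# which does not require FlashAttention kernels on the target device.
out = F.scaled_dot_product_attention(q, k, v)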
@@ -282,7 +282,7 @@ def bind_kv_cache(
        # TODO - analyze where runner_kv_caches is used and the right
        # way to ensure it properly reflects multiple attention layers
        # in the same decoder block.
-       if current_platform.is_cuda():
+       if current_platform.is_cuda() or current_platform.is_xpu():
            # We know that the GPU runner is not impacted by this
            # case. Some test code depends on runner_kv_caches, but
            # not in a way that's impacted by ignoring this.
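Together, the two hunks let the XPU worker route Whisper's attention through torch SDPA and bind its KV caches the same way the CUDA runner does. A hedged end-to-end usage sketch, assuming vLLM's public offline Whisper example applies unchanged on XPU (the model name, prompt fields, and AudioAsset helper come from that example, not from this commit, and may differ between versions):

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

# Whisper is encoder-decoder: the audio clip is passed as multi-modal
# data and the decoder prompt starts the transcript.
llm = LLM(model="openai/whisper-large-v3", max_model_len=448)

prompt = {
    "prompt": "<|startoftranscript|>",
    "multi_modal_data": {
        "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
    },
}

outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)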