mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 04:45:01 +08:00
[Bugfix] Add device assertion to TorchSDPA (#5402)
This commit is contained in:
parent
1a8bfd92d5
commit
c3c2903e72
@ -58,6 +58,9 @@ def get_attn_backend(
|
|||||||
ROCmFlashAttentionBackend)
|
ROCmFlashAttentionBackend)
|
||||||
return ROCmFlashAttentionBackend
|
return ROCmFlashAttentionBackend
|
||||||
elif backend == _Backend.TORCH_SDPA:
|
elif backend == _Backend.TORCH_SDPA:
|
||||||
|
# TODO: make XPU backend available here.
|
||||||
|
assert is_cpu(), RuntimeError(
|
||||||
|
"Torch SDPA backend is only used for the CPU device.")
|
||||||
logger.info("Using Torch SDPA backend.")
|
logger.info("Using Torch SDPA backend.")
|
||||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||||
return TorchSDPABackend
|
return TorchSDPABackend
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user