mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 12:35:01 +08:00
[Bugfix] Add device assertion to TorchSDPA (#5402)
This commit is contained in:
parent
1a8bfd92d5
commit
c3c2903e72
@ -58,6 +58,9 @@ def get_attn_backend(
|
||||
ROCmFlashAttentionBackend)
|
||||
return ROCmFlashAttentionBackend
|
||||
elif backend == _Backend.TORCH_SDPA:
|
||||
# TODO: make XPU backend available here.
|
||||
assert is_cpu(), RuntimeError(
|
||||
"Torch SDPA backend is only used for the CPU device.")
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||
return TorchSDPABackend
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user