[Bugfix] Add device assertion to TorchSDPA (#5402)

This commit is contained in:
Li, Jiang 2024-06-13 03:58:53 +08:00 committed by GitHub
parent 1a8bfd92d5
commit c3c2903e72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -58,6 +58,9 @@ def get_attn_backend(
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
# TODO: make XPU backend available here.
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend