mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-22 19:47:08 +08:00
fix(rocm): add early return in get_flash_attn_version for ROCm
Prevents spurious "libcudart.so.12 not found" errors by skipping the CUDA-specific vllm_flash_attn import on the ROCm platform.

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
parent
7adeb4bfa8
commit
3b1a3cae0e
@@ -31,7 +31,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_xpu():
|
||||
if current_platform.is_xpu() or current_platform.is_rocm():
|
||||
return 2
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user