fix(rocm): add early return in get_flash_attn_version for ROCm

Prevents spurious "libcudart.so.12 not found" errors by skipping
the CUDA-specific vllm_flash_attn import on ROCm platform.

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
rabi 2025-12-24 19:23:14 +05:30
parent 7adeb4bfa8
commit 3b1a3cae0e

View File

@ -31,7 +31,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
if current_platform.is_xpu():
if current_platform.is_xpu() or current_platform.is_rocm():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (