diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index e38c88f4838d1..1daa79762471f 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -31,7 +31,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: # import here to avoid circular dependencies from vllm.platforms import current_platform - if current_platform.is_xpu(): + if current_platform.is_xpu() or current_platform.is_rocm(): return 2 try: from vllm.vllm_flash_attn.flash_attn_interface import (