diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index a26e713b1c624..49070e4c7ae6a 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -25,9 +25,14 @@ Not currently supported:
 import torch
 
 from vllm.platforms import current_platform
-from vllm.platforms.rocm import on_gfx1x
 from vllm.triton_utils import tl, triton
 
+# Avoid misleading ROCm warning.
+if current_platform.is_rocm():
+    from vllm.platforms.rocm import on_gfx1x
+else:
+    on_gfx1x = lambda *args, **kwargs: False
+
 torch_dtype: tl.constexpr = torch.float16
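The patch moves the ROCm-only import behind a runtime platform check, so merely importing this module on non-ROCm builds no longer pulls in `vllm.platforms.rocm` (whose import is what emits the misleading warning), while call sites can keep calling `on_gfx1x()` unconditionally. Below is a minimal, self-contained sketch of the same guarded-import-with-fallback pattern; `platform_is_rocm()` is a hypothetical stand-in for vLLM's `current_platform.is_rocm()` so the example runs without a ROCm install.

```python
# Sketch of the guarded-import pattern used in the patch above.
# platform_is_rocm() is a stand-in for vLLM's current_platform.is_rocm().

def platform_is_rocm() -> bool:
    # Stand-in probe; always False here so the example runs anywhere.
    return False


if platform_is_rocm():
    # Only reached on ROCm builds, so the ROCm-specific module (and any
    # warning its import emits) is never touched on other platforms.
    from vllm.platforms.rocm import on_gfx1x
else:
    def on_gfx1x(*args, **kwargs) -> bool:
        # No-op fallback: a non-ROCm platform is never a gfx1x GPU.
        return False


# Call sites need no platform checks of their own.
print(on_gfx1x())  # -> False on non-ROCm hosts
```

The patch binds the fallback with a `lambda`; the `def` used in this sketch is behaviorally equivalent and is the form flake8's E731 rule prefers for named callables.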