diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py
index 20f573821b25f..d86041d71febd 100755
--- a/tests/kernels/attention/test_cascade_flash_attn.py
+++ b/tests/kernels/attention/test_cascade_flash_attn.py
@@ -7,11 +7,19 @@ import torch
 
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
-from vllm.vllm_flash_attn import (
-    fa_version_unsupported_reason,
-    flash_attn_varlen_func,
-    is_fa_version_supported,
-)
+
+try:
+    from vllm.vllm_flash_attn import (
+        fa_version_unsupported_reason,
+        flash_attn_varlen_func,
+        is_fa_version_supported,
+    )
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "vllm_flash_attn is not supported for vLLM on ROCm.",
+            allow_module_level=True,
+        )
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]
diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index 26b8c77ab482f..bbd5df5419f80 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -6,11 +6,20 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (
-    fa_version_unsupported_reason,
-    flash_attn_varlen_func,
-    is_fa_version_supported,
-)
+
+try:
+    from vllm.vllm_flash_attn import (
+        fa_version_unsupported_reason,
+        flash_attn_varlen_func,
+        is_fa_version_supported,
+    )
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "vllm_flash_attn is not supported for vLLM on ROCm.",
+            allow_module_level=True,
+        )
+
 
 NUM_HEADS = [(4, 4), (8, 2)]
 HEAD_SIZES = [40, 72, 80, 128, 256]
diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py
index 82ec2ef14e56c..eedeec33e0d45 100644
--- a/tests/kernels/attention/test_flashinfer.py
+++ b/tests/kernels/attention/test_flashinfer.py
@@ -2,12 +2,20 @@
 
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import flashinfer
 import pytest
-import torch
 
 from vllm.platforms import current_platform
 
+try:
+    import flashinfer
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
+        )
+
+import torch
+
 NUM_HEADS = [(32, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py
index 0350136677c6b..d183f67d3919e 100644
--- a/tests/kernels/attention/test_flashinfer_mla_decode.py
+++ b/tests/kernels/attention/test_flashinfer_mla_decode.py
@@ -3,7 +3,6 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 from torch import Tensor
 
 from vllm.platforms import current_platform
@@ -15,6 +14,8 @@ if not current_platform.has_device_capability(100):
         reason="FlashInfer MLA Requires compute capability of 10 or above.",
         allow_module_level=True,
     )
+else:
+    from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
 
 def ref_mla(
diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index 693b849ebc5d7..98ea40608b468 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import flashinfer
 import pytest
 import torch
 
@@ -16,6 +15,8 @@ if not current_platform.is_device_capability(100):
     pytest.skip(
         "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
     )
+else:
+    import flashinfer
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = current_platform.fp8_dtype()
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 218df4a2632c3..638741e91619b 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -22,7 +22,14 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
 from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
 from vllm.model_executor.models.llama4 import Llama4MoE
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+
+try:
+    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+except ImportError:
+    if current_platform.is_rocm():
+        pytest.skip(
+            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
+        )
 
 if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
     90
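For reference, every hunk above applies the same guard: wrap the CUDA-only import in `try`/`except ImportError` and skip the whole test module on ROCm at collection time. A minimal standalone sketch of the pattern is below; `some_cuda_only_extension` is a hypothetical placeholder standing in for `vllm_flash_attn` / `flashinfer`, while `pytest.skip(..., allow_module_level=True)` and `current_platform.is_rocm()` are the real APIs used in the patch.

```python
# Sketch only: "some_cuda_only_extension" is a hypothetical stand-in for the
# optional CUDA-only dependency guarded in the hunks above.
import pytest

from vllm.platforms import current_platform

try:
    import some_cuda_only_extension  # hypothetical optional dependency
except ImportError:
    if current_platform.is_rocm():
        # Called at import time, this skips every test in the module during
        # collection instead of failing the file with an ImportError on ROCm.
        pytest.skip(
            "some_cuda_only_extension is not available on ROCm.",
            allow_module_level=True,
        )
```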