diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py
index 1dec46e33f22e..8f58c470d217a 100644
--- a/tests/kernels/attention/test_aiter_flash_attn.py
+++ b/tests/kernels/attention/test_aiter_flash_attn.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 import vllm.v1.attention.backends.rocm_aiter_fa  # noqa: F401
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
 from vllm.platforms import current_platform
 
 NUM_HEADS = [(4, 4), (8, 2)]
@@ -100,6 +101,8 @@ def test_varlen_with_paged_kv(
     num_blocks: int,
     q_dtype: torch.dtype | None,
 ) -> None:
+    if not is_flash_attn_varlen_func_available():
+        pytest.skip("flash_attn_varlen_func required to run this test.")
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)