diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py
index 6958d62dc7e90..a4ee53008ce82 100644
--- a/tests/v1/spec_decode/test_tree_attention.py
+++ b/tests/v1/spec_decode/test_tree_attention.py
@@ -3,6 +3,7 @@
 
 import math
 
+import pytest
 import torch
 
 from tests.v1.attention.utils import (
@@ -11,9 +12,16 @@ from tests.v1.attention.utils import (
     try_get_attention_backend,
 )
 from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
 from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
+if not is_flash_attn_varlen_func_available():
+    pytest.skip(
+        "This test requires flash_attn_varlen_func, but it's not available.",
+        allow_module_level=True,
+    )
+
 
 class MockAttentionLayer(torch.nn.Module):
     _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")