[Test] Remove old non-varlen FA2 test (#28420)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni 2025-11-10 17:57:41 -06:00 committed by GitHub
parent a5a790eea6
commit 0bf29fadf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,7 +9,6 @@ from vllm.platforms import current_platform
from vllm.vllm_flash_attn import (
fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported,
)
@ -83,124 +82,6 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
sliding_window: int | None,
fa_version: int,
q_dtype: torch.dtype | None,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(
f"Flash attention version {fa_version} not supported due "
f'to: "{fa_version_unsupported_reason(fa_version)}"'
)
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
q = query.unsqueeze(1)
out = torch.empty_like(q) if use_out else None
maybe_quantized_query = q
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = q.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype)
scale_shape = (num_seqs, num_kv_heads)
q_descale = torch.ones(scale_shape, dtype=torch.float32)
k_descale = torch.ones(scale_shape, dtype=torch.float32)
v_descale = torch.ones(scale_shape, dtype=torch.float32)
output = flash_attn_with_kvcache(
q=maybe_quantized_query,
k_cache=maybe_quantized_key_cache,
v_cache=maybe_quantized_value_cache,
out=out,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
output = output if not use_out else out
output = output.squeeze(1)
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window,
)
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize(
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]