mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 05:34:55 +08:00
[Test] Remove old non-varlen FA2 test (#28420)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
a5a790eea6
commit
0bf29fadf5
@ -9,7 +9,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.vllm_flash_attn import (
|
||||
fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
@ -83,124 +82,6 @@ def ref_paged_attn(
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_out", [True, False])
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_flash_attn_with_paged_kv(
|
||||
use_out: bool,
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
sliding_window: int | None,
|
||||
fa_version: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(
|
||||
f"Flash attention version {fa_version} not supported due "
|
||||
f'to: "{fa_version_unsupported_reason(fa_version)}"'
|
||||
)
|
||||
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
|
||||
pytest.skip(
|
||||
"Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
q = query.unsqueeze(1)
|
||||
out = torch.empty_like(q) if use_out else None
|
||||
|
||||
maybe_quantized_query = q
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
q_descale = None
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = q.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
q_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
k_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
output = flash_attn_with_kvcache(
|
||||
q=maybe_quantized_query,
|
||||
k_cache=maybe_quantized_key_cache,
|
||||
v_cache=maybe_quantized_value_cache,
|
||||
out=out,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
cache_seqlens=kv_lens_tensor,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
window_size=window_size,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
output = output if not use_out else out
|
||||
output = output.squeeze(1)
|
||||
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_out", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user