[Misc] Add sliding window to flashinfer test (#21282)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-07-21 08:37:49 -07:00 (committed by GitHub)
Parent: 6b46c4b653
Commit: 6dda13c86b


@@ -77,6 +77,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(
    kv_lens: list[int],
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    sliding_window: Optional[int],
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
        use_tensor_cores=(
            (num_query_heads//num_kv_heads) > 4)
    )
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 q_data_type=dtype,
                 kv_data_type=dtype,
                 logits_soft_cap=soft_cap)
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_query_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        window_left=sliding_window - 1 if sliding_window is not None else -1,
        q_data_type=dtype,
        kv_data_type=dtype,
        logits_soft_cap=soft_cap,
    )
    output = wrapper.run(query, key_value_cache)
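
The only behavioral knob added to the plan() calls is window_left. The conversion above suggests that vLLM's sliding_window counts the attended span including the current token, whereas FlashInfer's window_left counts only the tokens strictly to its left, with -1 meaning no window at all. A minimal sketch of how the parametrized values map, mirroring the expression in the diff:

    for sliding_window in (None, 64):
        # -1 disables the window; otherwise drop the current token from the count.
        window_left = sliding_window - 1 if sliding_window is not None else -1
        print(sliding_window, "->", window_left)  # None -> -1, 64 -> 63
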
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        soft_cap=soft_cap)
        soft_cap=soft_cap,
        sliding_window=sliding_window)
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
                                          num_heads: tuple[int, int],
                                          head_size: int, dtype: torch.dtype,
                                          block_size: int,
                                          soft_cap: Optional[float]) -> None:
def test_flashinfer_prefill_with_paged_kv(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    sliding_window: Optional[int],
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
        num_kv_heads,
        head_size,
        block_size,
        window_left=sliding_window - 1 if sliding_window is not None else -1,
        q_data_type=dtype,
        kv_data_type=dtype,
        logits_soft_cap=soft_cap,
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        soft_cap=soft_cap)
        soft_cap=soft_cap,
        sliding_window=sliding_window)
    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"