From 6dda13c86ba17ca6bc054293d135bad2d1ab7129 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 21 Jul 2025 08:37:49 -0700
Subject: [PATCH] [Misc] Add sliding window to flashinfer test (#21282)

Signed-off-by: Woosuk Kwon
---
 tests/kernels/attention/test_flashinfer.py | 49 ++++++++++++++--------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py
index 3ad6e1d32911b..8f9b4eceaa72b 100644
--- a/tests/kernels/attention/test_flashinfer.py
+++ b/tests/kernels/attention/test_flashinfer.py
@@ -77,6 +77,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("sliding_window", [None, 64])
 @torch.inference_mode
 def test_flashinfer_decode_with_paged_kv(
     kv_lens: list[int],
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
     dtype: torch.dtype,
     block_size: int,
     soft_cap: Optional[float],
+    sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
         use_tensor_cores=(
             (num_query_heads//num_kv_heads) > 4)
     )
-    wrapper.plan(kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_query_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 "NONE",
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
-                 logits_soft_cap=soft_cap)
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_lens,
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        block_size,
+        "NONE",
+        window_left=sliding_window - 1 if sliding_window is not None else -1,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        logits_soft_cap=soft_cap,
+    )
 
     output = wrapper.run(query, key_value_cache)
 
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
         kv_lens=kv_lens,
         block_tables=block_tables,
         scale=scale,
-        soft_cap=soft_cap)
+        soft_cap=soft_cap,
+        sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
+@pytest.mark.parametrize("sliding_window", [None, 64])
 @torch.inference_mode
-def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
-                                          num_heads: tuple[int, int],
-                                          head_size: int, dtype: torch.dtype,
-                                          block_size: int,
-                                          soft_cap: Optional[float]) -> None:
+def test_flashinfer_prefill_with_paged_kv(
+    seq_lens: list[tuple[int, int]],
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    block_size: int,
+    soft_cap: Optional[float],
+    sliding_window: Optional[int],
+) -> None:
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
         num_kv_heads,
         head_size,
         block_size,
+        window_left=sliding_window - 1 if sliding_window is not None else -1,
         q_data_type=dtype,
         kv_data_type=dtype,
         logits_soft_cap=soft_cap,
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
         kv_lens=kv_lens,
         block_tables=block_tables,
         scale=scale,
-        soft_cap=soft_cap)
+        soft_cap=soft_cap,
+        sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
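
Note (not part of the patch): a minimal sketch of how the test's `sliding_window` value relates to the `window_left` argument passed to `wrapper.plan()` and to an explicit causal-plus-window mask. The helper name `sliding_window_mask` and the single-sequence setup are illustrative assumptions; the actual test drives the comparison through paged KV caches and `ref_paged_attn`.

# Sketch (assumption, not from the patch): relate `sliding_window` to the
# `window_left` value used in plan() and to an explicit attention mask.
from typing import Optional

import torch


def sliding_window_mask(query_len: int, kv_len: int,
                        sliding_window: Optional[int]) -> torch.Tensor:
    """Boolean (query_len, kv_len) mask; True where attention is allowed."""
    # Absolute position of each query token, assuming the queries are the
    # last `query_len` tokens of a `kv_len`-token context.
    q_pos = torch.arange(kv_len - query_len, kv_len).unsqueeze(1)
    k_pos = torch.arange(kv_len).unsqueeze(0)
    mask = k_pos <= q_pos  # causal: keys at or before the query position
    if sliding_window is not None:
        # A window of size W keeps the current token plus the W - 1 tokens
        # before it, which is the convention implied by passing
        # window_left = sliding_window - 1 to plan(); -1 disables the window.
        mask &= (q_pos - k_pos) < sliding_window
    return mask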