[Misc] Add sliding window to flashinfer test (#21282)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-07-21 08:37:49 -07:00 committed by GitHub
parent 6b46c4b653
commit 6dda13c86b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -77,6 +77,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_decode_with_paged_kv( def test_flashinfer_decode_with_paged_kv(
kv_lens: list[int], kv_lens: list[int],
@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
sliding_window: Optional[int],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=( use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4) (num_query_heads//num_kv_heads) > 4)
) )
wrapper.plan(kv_indptr, wrapper.plan(
kv_indices, kv_indptr,
kv_last_page_lens, kv_indices,
num_query_heads, kv_last_page_lens,
num_kv_heads, num_query_heads,
head_size, num_kv_heads,
block_size, head_size,
"NONE", block_size,
q_data_type=dtype, "NONE",
kv_data_type=dtype, window_left=sliding_window - 1 if sliding_window is not None else -1,
logits_soft_cap=soft_cap) q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.run(query, key_value_cache) output = wrapper.run(query, key_value_cache)
@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], def test_flashinfer_prefill_with_paged_kv(
num_heads: tuple[int, int], seq_lens: list[tuple[int, int]],
head_size: int, dtype: torch.dtype, num_heads: tuple[int, int],
block_size: int, head_size: int,
soft_cap: Optional[float]) -> None: dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
sliding_window: Optional[int],
) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
window_left=sliding_window - 1 if sliding_window is not None else -1,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=dtype, kv_data_type=dtype,
logits_soft_cap=soft_cap, logits_soft_cap=soft_cap,
@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"