mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 12:22:19 +08:00
[Misc] Add sliding window to flashinfer test (#21282)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
6b46c4b653
commit
6dda13c86b
@ -77,6 +77,7 @@ def ref_paged_attn(
|
|||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||||
|
@pytest.mark.parametrize("sliding_window", [None, 64])
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_decode_with_paged_kv(
|
def test_flashinfer_decode_with_paged_kv(
|
||||||
kv_lens: list[int],
|
kv_lens: list[int],
|
||||||
@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
soft_cap: Optional[float],
|
soft_cap: Optional[float],
|
||||||
|
sliding_window: Optional[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
|
|||||||
use_tensor_cores=(
|
use_tensor_cores=(
|
||||||
(num_query_heads//num_kv_heads) > 4)
|
(num_query_heads//num_kv_heads) > 4)
|
||||||
)
|
)
|
||||||
wrapper.plan(kv_indptr,
|
wrapper.plan(
|
||||||
kv_indices,
|
kv_indptr,
|
||||||
kv_last_page_lens,
|
kv_indices,
|
||||||
num_query_heads,
|
kv_last_page_lens,
|
||||||
num_kv_heads,
|
num_query_heads,
|
||||||
head_size,
|
num_kv_heads,
|
||||||
block_size,
|
head_size,
|
||||||
"NONE",
|
block_size,
|
||||||
q_data_type=dtype,
|
"NONE",
|
||||||
kv_data_type=dtype,
|
window_left=sliding_window - 1 if sliding_window is not None else -1,
|
||||||
logits_soft_cap=soft_cap)
|
q_data_type=dtype,
|
||||||
|
kv_data_type=dtype,
|
||||||
|
logits_soft_cap=soft_cap,
|
||||||
|
)
|
||||||
|
|
||||||
output = wrapper.run(query, key_value_cache)
|
output = wrapper.run(query, key_value_cache)
|
||||||
|
|
||||||
@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
|
|||||||
kv_lens=kv_lens,
|
kv_lens=kv_lens,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
soft_cap=soft_cap)
|
soft_cap=soft_cap,
|
||||||
|
sliding_window=sliding_window)
|
||||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||||
f"{torch.max(torch.abs(output - ref_output))}"
|
f"{torch.max(torch.abs(output - ref_output))}"
|
||||||
|
|
||||||
@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
|
|||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||||
|
@pytest.mark.parametrize("sliding_window", [None, 64])
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
|
def test_flashinfer_prefill_with_paged_kv(
|
||||||
num_heads: tuple[int, int],
|
seq_lens: list[tuple[int, int]],
|
||||||
head_size: int, dtype: torch.dtype,
|
num_heads: tuple[int, int],
|
||||||
block_size: int,
|
head_size: int,
|
||||||
soft_cap: Optional[float]) -> None:
|
dtype: torch.dtype,
|
||||||
|
block_size: int,
|
||||||
|
soft_cap: Optional[float],
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
num_seqs = len(seq_lens)
|
num_seqs = len(seq_lens)
|
||||||
@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
|
|||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_size,
|
head_size,
|
||||||
block_size,
|
block_size,
|
||||||
|
window_left=sliding_window - 1 if sliding_window is not None else -1,
|
||||||
q_data_type=dtype,
|
q_data_type=dtype,
|
||||||
kv_data_type=dtype,
|
kv_data_type=dtype,
|
||||||
logits_soft_cap=soft_cap,
|
logits_soft_cap=soft_cap,
|
||||||
@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
|
|||||||
kv_lens=kv_lens,
|
kv_lens=kv_lens,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
soft_cap=soft_cap)
|
soft_cap=soft_cap,
|
||||||
|
sliding_window=sliding_window)
|
||||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
|
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
|
||||||
f"{torch.max(torch.abs(output - ref_output))}"
|
f"{torch.max(torch.abs(output - ref_output))}"
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user