mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 10:00:13 +08:00
[cpu][ci] Add CPU Attention Tests for Neon Backend (#30347)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
parent
ed7af3178a
commit
434ac76a7c
@ -7,7 +7,8 @@ import math
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
|
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
|
||||||
|
|
||||||
if not current_platform.is_cpu():
|
if not current_platform.is_cpu():
|
||||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||||
@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_isa(
|
||||||
|
block_size: int | None = None,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
):
|
||||||
|
if block_size and dtype:
|
||||||
|
return _get_attn_isa(dtype, block_size)
|
||||||
|
else:
|
||||||
|
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||||
|
return "neon"
|
||||||
|
elif torch._C._cpu._is_amx_tile_supported():
|
||||||
|
return "amx"
|
||||||
|
else:
|
||||||
|
return "vec"
|
||||||
|
|
||||||
|
|
||||||
# rand number generation takes too much time, cache rand tensors
|
# rand number generation takes too much time, cache rand tensors
|
||||||
@functools.lru_cache(maxsize=128, typed=False)
|
@functools.lru_cache(maxsize=128, typed=False)
|
||||||
def tensor_cache(
|
def tensor_cache(
|
||||||
@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", [96, 128])
|
||||||
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
|
@pytest.mark.parametrize("dtype", QTYPES)
|
||||||
|
@pytest.mark.parametrize("soft_cap", [None])
|
||||||
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
|
@pytest.mark.parametrize("use_alibi", [False])
|
||||||
|
@pytest.mark.parametrize("use_sink", [False])
|
||||||
|
@pytest.mark.parametrize("isa", ["neon"])
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
|
||||||
|
reason="Not an Arm CPU.",
|
||||||
|
)
|
||||||
|
def test_varlen_with_paged_kv_normal_neon(
|
||||||
|
seq_lens: list[tuple[int, int]],
|
||||||
|
num_heads: tuple[int, int],
|
||||||
|
head_size: int,
|
||||||
|
sliding_window: int | None,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
block_size: int,
|
||||||
|
soft_cap: float | None,
|
||||||
|
num_blocks: int,
|
||||||
|
use_alibi: bool,
|
||||||
|
use_sink: bool,
|
||||||
|
isa: str,
|
||||||
|
) -> None:
|
||||||
|
varlen_with_paged_kv(
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
dtype=dtype,
|
||||||
|
block_size=block_size,
|
||||||
|
soft_cap=soft_cap,
|
||||||
|
num_blocks=num_blocks,
|
||||||
|
use_alibi=use_alibi,
|
||||||
|
use_sink=use_sink,
|
||||||
|
isa=isa,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", [96])
|
@pytest.mark.parametrize("head_size", [96])
|
||||||
@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
|
|||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("use_alibi", [False])
|
@pytest.mark.parametrize("use_alibi", [False])
|
||||||
@pytest.mark.parametrize("use_sink", [False])
|
@pytest.mark.parametrize("use_sink", [False])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
|
||||||
)
|
|
||||||
def test_varlen_with_paged_kv_softcap(
|
def test_varlen_with_paged_kv_softcap(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
num_heads: tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
|
|||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("use_alibi", [True])
|
@pytest.mark.parametrize("use_alibi", [True])
|
||||||
@pytest.mark.parametrize("use_sink", [False])
|
@pytest.mark.parametrize("use_sink", [False])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
|
||||||
)
|
|
||||||
def test_varlen_with_paged_kv_alibi(
|
def test_varlen_with_paged_kv_alibi(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
num_heads: tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
|
|||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("use_alibi", [False])
|
@pytest.mark.parametrize("use_alibi", [False])
|
||||||
@pytest.mark.parametrize("use_sink", [True])
|
@pytest.mark.parametrize("use_sink", [True])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
|
||||||
)
|
|
||||||
def test_varlen_with_paged_kv_sink(
|
def test_varlen_with_paged_kv_sink(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
num_heads: tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user