[CI Perf] Prune tests in tests/kernels/quantization/ (#22942)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent ae05a6d83d
commit 590bddbfc5
@@ -11,11 +11,9 @@ from tests.kernels.quant_utils import (FP8_DTYPE,
 from tests.kernels.utils import opcheck
 from vllm.platforms import current_platform
 
-DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
-                8193]  # Arbitrary values for testing
-HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
-NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+DTYPES = [torch.bfloat16, torch.float]
+HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
+NUM_TOKENS = [1, 7, 4096]
 SCALE_UBS = [True, False]
 SEEDS = [0]
 
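Pruning pays off multiplicatively here: stacked @pytest.mark.parametrize decorators take the cartesian product of their axes. A back-of-the-envelope count for the fp8 quant hunk above, assuming the five lists feed five stacked decorators as the surrounding code suggests:

    # Decorator stacking multiplies the axes, so the case count is the
    # product of the list lengths.
    old = 3 * (12 + 9) * 4 * 2 * 1  # DTYPES * HIDDEN_SIZES * NUM_TOKENS * SCALE_UBS * SEEDS
    new = 2 * 6 * 3 * 2 * 1
    print(old, new)  # 504 72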
@@ -9,10 +9,9 @@ from tests.kernels.utils import opcheck
 from vllm._custom_ops import scaled_int8_quant
 from vllm.platforms import current_platform
 
-DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [16, 67, 768, 5137, 8193]  # Arbitrary values for testing
-HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
-NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+DTYPES = [torch.bfloat16, torch.float]
+HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
+NUM_TOKENS = [1, 7, 4096]
 SEEDS = [0]
 SCALE = [0.1, 2.1]
 
@@ -34,8 +34,6 @@ IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
 
 MNK_SHAPES = [
     (1, 128, 128),
-    (1, 512, 1024),
-    (1, 4096, 4096),
     (1, 8192, 28672),
     (13, 8192, 4096),
     (26, 4096, 8192),
@@ -43,8 +41,6 @@ MNK_SHAPES = [
     (64, 8192, 28672),
     (257, 128, 4096),
     (257, 4224, 4160),
-    (257, 4096, 4096),
-    (1024, 4096, 8192),
     (1024, 8192, 4096),
 ]
 
@@ -53,12 +53,8 @@ HQQ_SUPPORTED_GROUP_SIZES = [64]
 MNK_FACTORS = [
     (1, 1, 1),
     (1, 4, 8),
-    (1, 7, 5),
-    (13, 17, 67),
     (26, 37, 13),
-    (67, 13, 11),
     (257, 13, 11),
-    (658, 13, 11),
 ]
 
 DTYPES = [torch.float16, torch.bfloat16]
@@ -8,15 +8,55 @@ from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
 from vllm.platforms import current_platform
 
 DTYPES = [torch.bfloat16, torch.float16]
-M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
-K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192]  # k % 8 == 0
-N = [1, 2, 3, 4]
+# Specific (N, K, M) combinations for targeted testing
+NKM_FACTORS_LLMM1 = [
+    # Small, medium, large cases
+    (1, 8, 16),
+    (1, 32, 64),
+    (1, 128, 256),
+    (1, 512, 1024),
+    (1, 2048, 4096),
+    # Edge cases with specific K sizes
+    (1, 6144, 1024),
+    (1, 8192, 2048),
+    # Very large case
+    (1, 4096, 8192),
+]
+
+NKM_FACTORS_WVSPLITK = [
+    # Different batch sizes with key dimensions
+    (1, 16, 16),
+    (1, 64, 64),
+    (2, 256, 256),
+    (3, 1024, 1024),
+    (4, 4096, 4096),
+    # Extended K values
+    (1, 9216, 512),
+    (2, 10240, 1024),
+    (4, 16384, 8192),
+    # Minimum M constraint validation (m >= 8)
+    (1, 64, 8),
+    (2, 128, 8),
+    (4, 256, 8),
+]
+
+NKM_FACTORS_WVSPLITK_FP8 = [
+    # FP8-specific cases with K % 16 == 0
+    (1, 16, 16),
+    (1, 64, 64),
+    (2, 512, 512),
+    (3, 2048, 2048),
+    (4, 4096, 4096),
+    # Extended FP8 dimensions not covered by WVSPLITK
+    (1, 14336, 1024),
+    (2, 24576, 2048),
+    (4, 32768, 28672),
+]
+
 SEEDS = [0]
 
 
-@pytest.mark.parametrize("n", [1])  # only test for batch size 1
-@pytest.mark.parametrize("k", K)
-@pytest.mark.parametrize("m", M)
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
 @pytest.mark.parametrize("seed", SEEDS)
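This hunk shows the commit's core pattern: three independent parametrize axes (a full cross product) become one comma-separated parametrize over curated tuples, which yields exactly one test per tuple. A minimal sketch with a hypothetical toy test, not taken from the diff:

    import pytest

    FACTORS = [(1, 8, 16), (2, 128, 8)]  # hand-picked (n, k, m) triples

    # One paired parametrize: exactly len(FACTORS) == 2 cases.
    @pytest.mark.parametrize("n,k,m", FACTORS)
    def test_paired(n, k, m):
        assert n * k * m > 0

    # The old style multiplies the axes instead: 2 * 2 * 2 == 8 cases.
    # @pytest.mark.parametrize("n", [1, 2])
    # @pytest.mark.parametrize("k", [8, 128])
    # @pytest.mark.parametrize("m", [8, 16])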
@@ -34,9 +74,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
     assert torch.allclose(out, ref_out, rtol=0.01)
 
 
-@pytest.mark.parametrize("n", N)  # only test for batch size <= 4
-@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
-@pytest.mark.parametrize("m", [8] + M)  # m >= 8
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not current_platform.is_rocm(),
@@ -54,9 +92,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
     assert torch.allclose(out, ref_out, rtol=0.01)
 
 
-@pytest.mark.parametrize("n", N)  # only test for batch size <= 4
-@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768])  # k % 16 == 0
-@pytest.mark.parametrize("m", M + [28672])  # m >= 16
+@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(
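Note that the curated FP8 list keeps the kernel's alignment constraint, previously enforced by slicing (K[1:]) plus a "# k % 16 == 0" comment, visible in the data itself: every K is a multiple of 16. A quick ad hoc check, not part of the test file:

    # (n, k, m) triples from the hunk above; the middle element is K.
    NKM_FACTORS_WVSPLITK_FP8 = [
        (1, 16, 16), (1, 64, 64), (2, 512, 512), (3, 2048, 2048),
        (4, 4096, 4096), (1, 14336, 1024), (2, 24576, 2048), (4, 32768, 28672),
    ]
    assert all(k % 16 == 0 for _, k, _ in NKM_FACTORS_WVSPLITK_FP8)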
@@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
                                       num_logprobs)
 
 
-@pytest.mark.parametrize("M", [1, 33, 64, 512])
-@pytest.mark.parametrize("N", [256, 971, 20486])
-@pytest.mark.parametrize("K", [128, 496, 1024])
-@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+MNK_FACTORS = [
+    (1, 256, 128),
+    (33, 256, 496),
+    (64, 971, 1024),
+    (64, 20486, 128),
+    (512, 256, 496),
+    (512, 20486, 1024),
+]
+
+
+@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
 @pytest.mark.parametrize("in_dtype", get_8bit_types())
 @pytest.mark.parametrize("use_scalar_scale_a", [True, False])
 @pytest.mark.parametrize("use_scalar_scale_b", [True, False])
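For the scaled_mm hunk above, the old grid was 4 * 3 * 3 * 2 = 72 (M, N, K, out_dtype) combinations before multiplying in the in_dtype and scale-flag axes; the new one is 6 tuples * 1 dtype = 6, and the tuples still touch every per-axis M, N, and K value from the old lists. A quick ad hoc check, not part of the test file:

    MNK_FACTORS = [(1, 256, 128), (33, 256, 496), (64, 971, 1024),
                   (64, 20486, 128), (512, 256, 496), (512, 20486, 1024)]
    Ms, Ns, Ks = (set(axis) for axis in zip(*MNK_FACTORS))
    assert Ms == {1, 33, 64, 512}    # all old M values covered
    assert Ns == {256, 971, 20486}   # all old N values covered
    assert Ks == {128, 496, 1024}    # all old K values covered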