Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 04:54:56 +08:00)
kernels/moe test pruning (#27053)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>

parent 17d055f527 · commit b8c48c5d72
@@ -24,23 +24,16 @@ from vllm.triton_utils import tl
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 2048),
    (1, 512, 512),
    (1, 1024, 128),
    (1, 1024, 2048),
    (32, 128, 128),
    (32, 512, 512),
    (32, 1024, 2048),
    (45, 128, 128),
    (45, 128, 2048),
    (45, 512, 512),
    (45, 1024, 128),
    (45, 1024, 2048),
    (64, 512, 512),
    (64, 1024, 2048),
    (222, 128, 128),
    (222, 128, 2048),
    (222, 1024, 128),
    (222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
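Each (m, n, k) tuple above becomes a separate test case through pytest parametrization (the `@pytest.mark.parametrize("m,n,k", MNK_FACTORS)` pattern that appears later in this diff), so trimming the list directly shrinks the CI matrix. A minimal illustrative sketch, with a hypothetical `test_shapes` test and a trimmed example list:

import pytest

# Hypothetical trimmed shape list; each tuple becomes one test case per
# combination of the other parametrized axes (expert count, topk, dtype, ...).
MNK_FACTORS = [(1, 128, 128), (32, 512, 512), (222, 1024, 2048)]


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("num_experts", [8, 64])
def test_shapes(m: int, n: int, k: int, num_experts: int):
    # 3 shapes x 2 expert counts = 6 cases, versus 19 x 2 = 38 with the full list above.
    assert m > 0 and n > 0 and k > 0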
@@ -117,10 +110,19 @@ def test_batched_mm(
    block_shape: list[int] | None,
    per_act_token_quant: bool,
):
    """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
        89
    ):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")
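The device-capability gate in this hunk (and mirrored in the other test files below) is a standard pytest skip pattern. A minimal standalone sketch, assuming vllm's `current_platform` helper and a hypothetical `test_fp8_kernel` test:

import pytest
import torch

from vllm.platforms import current_platform


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float8_e4m3fn])
def test_fp8_kernel(dtype: torch.dtype):
    # float8_e4m3fn (Triton's fp8e4nv) needs CUDA compute capability >= 8.9,
    # so skip rather than fail on older GPUs.
    if dtype == torch.float8_e4m3fn and not current_platform.has_device_capability(89):
        pytest.skip("fp8e4nv data type is not supported on CUDA arch < 89")

    x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
    assert x.to(dtype).dtype == dtype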
@@ -244,10 +246,19 @@ def test_fused_moe_batched_experts(
    block_shape: list[int] | None,
    input_scales: bool,
):
    """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
        89
    ):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if topk > e:
        pytest.skip("topk > e")
@@ -42,57 +42,43 @@ DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
# and its hidden size is 7168.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 512, 512),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4608, 128),
    (1, 4608, 512),
    (1, 4608, 7168),
    (83, 128, 128),
    (83, 512, 512),
    (83, 1024, 7168),
    (83, 4608, 512),
    (83, 4608, 7168),
    (128, 128, 128),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4608, 512),
    (128, 4608, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4608, 512),
    (2048, 4608, 7168),
    (8192, 128, 128),
    (8192, 512, 512),
    (8192, 128, 7168),
    (8192, 1024, 7168),
    (8192, 4608, 512),
    (8192, 4608, 7168),
]

MNK_FACTORS_DG = [
    (128, 128, 128),
    (128, 512, 512),
    (128, 128, 7168),
    (128, 1024, 7168),
    (128, 4608, 128),
    (128, 4608, 512),
    (128, 4608, 7168),
    (192, 128, 128),
    (192, 512, 512),
    (192, 1024, 7168),
    (192, 4608, 512),
    (192, 4608, 7168),
    (1335, 128, 128),
    (1335, 1024, 7168),
    (1335, 4608, 512),
    (1335, 4608, 7168),
    (2048, 128, 128),
    (2048, 512, 512),
    (2048, 128, 7168),
    (2048, 1024, 7168),
    (2048, 4608, 128),
    (2048, 4608, 512),
    (2048, 4608, 7168),
]
@@ -21,36 +21,28 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

-DTYPES = [torch.half, torch.bfloat16]
+DTYPES = [torch.bfloat16]

MNK_FACTORS = [
    (1, 128, 128),
    (1, 512, 512),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4096, 128),
    (1, 4096, 512),
    (1, 4096, 7168),
    (33, 128, 128),
    (33, 512, 512),
    (33, 128, 7168),
    (33, 1024, 7168),
    (33, 4096, 128),
    (33, 4096, 512),
    (33, 4096, 7168),
    (128, 128, 128),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4096, 512),
    (128, 4096, 7168),
    (222, 128, 128),
    (222, 512, 512),
    (222, 1024, 7168),
    (222, 4096, 512),
    (222, 4096, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4096, 512),
    (2048, 4096, 4096),
]
@@ -26,16 +26,13 @@ TOP_KS = [6, 8]

MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (7, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 3072, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
    (224, 3072, 1024),
    (224, 3072, 1536),
    (32768, 1024, 1024),
@@ -393,7 +393,6 @@ def _test_deepep_deepgemm_moe(
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (8, 512, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
@@ -130,10 +130,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [
    (1024, 768, 128),
    (1024, 768, 512),
    (2048, 768, 512),
    (512, 1024, 1024),
    (512, 2048, 2048),
    (4096, 4096, 1024),
]
@@ -34,8 +34,6 @@ TOP_KS = [1]

MNK_FACTORS = [
    (256, 8192, 5120),
    (256, 4096, 5120),
    (127, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
@@ -34,10 +34,8 @@ if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_cap

MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1536),
@@ -49,7 +47,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
@@ -27,7 +27,7 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_grouped_topk(
    monkeypatch: pytest.MonkeyPatch,
    n_token: int,
@@ -295,6 +295,8 @@ def test_modular_kernel_combinations_singlegpu(
    world_size: int,
    pytestconfig,
):
    """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    config = Config(
        Ms=Ms,
        K=k,
@@ -309,6 +311,12 @@ def test_modular_kernel_combinations_singlegpu(
        world_size=world_size,
    )

    if (
        quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
    ) and not current_platform.has_device_capability(89):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )
    verbosity = pytestconfig.getoption("verbose")
    run(config, verbosity > 0)
@@ -66,8 +66,6 @@ FUSED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (1, 2048, 128),
    (33, 2048, 128),
    (222, 1024, 1024),
    (32768, 128, 128),
    (32768, 2048, 511),
    (40000, 1024, 1024),
]
@@ -76,7 +74,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
    (1, 128, 128),
    (1, 1024, 1024),
    (32, 2048, 128),
    (32, 1024, 1024),
    (222, 2048, 1024),
]
@@ -512,8 +509,8 @@ def marlin_moe_generate_valid_test_cases():
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
-    dtype_list = [torch.half, torch.bfloat16]
-    group_size_list = [-1, 16, 32, 128]
+    dtype_list = [torch.bfloat16]
+    group_size_list = [-1, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
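The generator trimmed above builds the cross product of these parameter lists and keeps only valid combinations. A generic sketch of that cross-product-then-filter pattern (hypothetical names and validity rule, not vLLM's actual filter):

import itertools

import pytest
import torch


def generate_valid_test_cases():
    # Hypothetical parameter lists and validity rule, shown only to illustrate
    # the pattern; vLLM's real constraints differ.
    e_list = [4, 12]
    topk_list = [2, 3]
    dtype_list = [torch.bfloat16]
    group_size_list = [-1, 32, 128]
    cases = []
    for e, topk, dtype, group_size in itertools.product(
        e_list, topk_list, dtype_list, group_size_list
    ):
        if topk > e:  # example validity check
            continue
        cases.append((e, topk, dtype, group_size))
    return cases


@pytest.mark.parametrize("e,topk,dtype,group_size", generate_valid_test_cases())
def test_case(e, topk, dtype, group_size):
    assert topk <= e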
@@ -885,10 +882,10 @@ def test_batched_moe_align_block_size_opcheck():
    )


-@pytest.mark.parametrize("m", [1, 33, 64, 222])
+@pytest.mark.parametrize("m", [1, 33, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
@@ -26,9 +26,7 @@ MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1536),
    (224, 1024, 1024),
@@ -39,7 +37,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
@@ -19,20 +19,16 @@ CASES = [
    (32, 64, 256, fp8_dtype),
    (17, 31, 768, fp8_dtype),
    (1, 1, 128 * 1, fp8_dtype),
    (1, 1, 128 * 2, fp8_dtype),
    (1, 1, 128 * 3, fp8_dtype),
    (1, 1, 128 * 4, fp8_dtype),
    (8, 16, 128 * 1, fp8_dtype),
    (8, 16, 128 * 2, fp8_dtype),
    (8, 16, 128 * 3, fp8_dtype),
    (8, 16, 128 * 4, fp8_dtype),
    (8, 64, 7168, fp8_dtype),
    (8, 128, 7168, fp8_dtype),
    (8, 256, 7168, fp8_dtype),
    (8, 512, 7168, fp8_dtype),
    (8, 1024, 7168, fp8_dtype),
    (256, 8, 7168, fp8_dtype),
    (256, 16, 7168, fp8_dtype),
    (256, 32, 7168, fp8_dtype),
    (256, 64, 7168, fp8_dtype),
    # Only add a few fnuz tests to help with long CI times.