kernels/moe test pruning (#27053)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Fardin Hoque 2025-10-29 21:10:34 -07:00 committed by GitHub
parent 17d055f527
commit b8c48c5d72
13 changed files with 34 additions and 56 deletions


@@ -24,23 +24,16 @@ from vllm.triton_utils import tl
MNK_FACTORS = [
(1, 128, 128),
(1, 128, 2048),
(1, 512, 512),
(1, 1024, 128),
(1, 1024, 2048),
(32, 128, 128),
(32, 512, 512),
(32, 1024, 2048),
(45, 128, 128),
(45, 128, 2048),
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 1024, 128),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
@@ -117,10 +110,19 @@ def test_batched_mm(
block_shape: list[int] | None,
per_act_token_quant: bool,
):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
89
):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
pytest.skip("Don't test blocking for non-quantized types.")
@@ -244,10 +246,19 @@ def test_fused_moe_batched_experts(
block_shape: list[int] | None,
input_scales: bool,
):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
89
):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if topk > e:
pytest.skip("topk > e")


@@ -42,57 +42,43 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# and its hidden size is 7168.
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4608, 128),
(1, 4608, 512),
(1, 4608, 7168),
(83, 128, 128),
(83, 512, 512),
(83, 1024, 7168),
(83, 4608, 512),
(83, 4608, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4608, 512),
(128, 4608, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4608, 512),
(2048, 4608, 7168),
(8192, 128, 128),
(8192, 512, 512),
(8192, 128, 7168),
(8192, 1024, 7168),
(8192, 4608, 512),
(8192, 4608, 7168),
]
MNK_FACTORS_DG = [
(128, 128, 128),
(128, 512, 512),
(128, 128, 7168),
(128, 1024, 7168),
(128, 4608, 128),
(128, 4608, 512),
(128, 4608, 7168),
(192, 128, 128),
(192, 512, 512),
(192, 1024, 7168),
(192, 4608, 512),
(192, 4608, 7168),
(1335, 128, 128),
(1335, 1024, 7168),
(1335, 4608, 512),
(1335, 4608, 7168),
(2048, 128, 128),
(2048, 512, 512),
(2048, 128, 7168),
(2048, 1024, 7168),
(2048, 4608, 128),
(2048, 4608, 512),
(2048, 4608, 7168),
]


@@ -21,36 +21,28 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16]
DTYPES = [torch.bfloat16]
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4096, 128),
(1, 4096, 512),
(1, 4096, 7168),
(33, 128, 128),
(33, 512, 512),
(33, 128, 7168),
(33, 1024, 7168),
(33, 4096, 128),
(33, 4096, 512),
(33, 4096, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4096, 512),
(128, 4096, 7168),
(222, 128, 128),
(222, 512, 512),
(222, 1024, 7168),
(222, 4096, 512),
(222, 4096, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 4096),
]


@@ -26,16 +26,13 @@ TOP_KS = [6, 8]
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(7, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 3072, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
(224, 3072, 1024),
(224, 3072, 1536),
(32768, 1024, 1024),


@@ -393,7 +393,6 @@ def _test_deepep_deepgemm_moe(
MNKs = [
(8, 128, 128),
(8, 128, 512),
(8, 512, 512),
(3, 1024, 2048),
(32, 128, 1024),
(45, 512, 2048),


@@ -130,10 +130,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [
(1024, 768, 128),
(1024, 768, 512),
(2048, 768, 512),
(512, 1024, 1024),
(512, 2048, 2048),
(4096, 4096, 1024),
]


@@ -34,8 +34,6 @@ TOP_KS = [1]
MNK_FACTORS = [
(256, 8192, 5120),
(256, 4096, 5120),
(127, 8192, 5120),
(127, 4096, 5120),
(10, 8192, 5120),
(10, 4096, 5120),


@@ -34,10 +34,8 @@ if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_cap
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
@@ -49,7 +47,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype


@@ -27,7 +27,7 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch,
n_token: int,


@@ -295,6 +295,8 @@ def test_modular_kernel_combinations_singlegpu(
world_size: int,
pytestconfig,
):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
config = Config(
Ms=Ms,
K=k,
@@ -309,6 +311,12 @@ def test_modular_kernel_combinations_singlegpu(
world_size=world_size,
)
if (
quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
) and not current_platform.has_device_capability(89):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)


@@ -66,8 +66,6 @@ FUSED_MOE_MNK_FACTORS = [
(1, 128, 128),
(1, 2048, 128),
(33, 2048, 128),
(222, 1024, 1024),
(32768, 128, 128),
(32768, 2048, 511),
(40000, 1024, 1024),
]
@@ -76,7 +74,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
(1, 128, 128),
(1, 1024, 1024),
(32, 2048, 128),
(32, 1024, 1024),
(222, 2048, 1024),
]
@@ -512,8 +509,8 @@ def marlin_moe_generate_valid_test_cases():
e_list = [4, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
dtype_list = [torch.half, torch.bfloat16]
group_size_list = [-1, 16, 32, 128]
dtype_list = [torch.bfloat16]
group_size_list = [-1, 32, 128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
@@ -885,10 +882,10 @@ def test_batched_moe_align_block_size_opcheck():
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("m", [1, 33, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)


@@ -26,9 +26,7 @@ MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
@@ -39,7 +37,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype


@@ -19,20 +19,16 @@ CASES = [
(32, 64, 256, fp8_dtype),
(17, 31, 768, fp8_dtype),
(1, 1, 128 * 1, fp8_dtype),
(1, 1, 128 * 2, fp8_dtype),
(1, 1, 128 * 3, fp8_dtype),
(1, 1, 128 * 4, fp8_dtype),
(8, 16, 128 * 1, fp8_dtype),
(8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(8, 16, 128 * 4, fp8_dtype),
(8, 64, 7168, fp8_dtype),
(8, 128, 7168, fp8_dtype),
(8, 256, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype),
(256, 8, 7168, fp8_dtype),
(256, 16, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype),
# Only add a few fnuz tests to help with long CI times.