From 6e8d8c4afbddf725b34ef938616701869f5b3462 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Fri, 1 Aug 2025 22:45:46 -0400
Subject: [PATCH] [Test] Add Unit Test for Batched DeepGEMM (#21559)

Signed-off-by: yewentao256
---
 tests/kernels/moe/test_batched_deepgemm.py | 103 +++++++++++++++++++++
 tests/kernels/moe/test_deepgemm.py         |   8 +-
 vllm/utils/deep_gemm.py                    |   4 +-
 3 files changed, 107 insertions(+), 8 deletions(-)
 create mode 100644 tests/kernels/moe/test_batched_deepgemm.py

diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py
new file mode 100644
index 0000000000000..018d4c224f75e
--- /dev/null
+++ b/tests/kernels/moe/test_batched_deepgemm.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+    BatchedDeepGemmExperts)
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedPrepareAndFinalize, BatchedTritonExperts)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
+
+from .test_deepgemm import make_block_quant_fp8_weights
+
+BLOCK_SIZE = [128, 128]
+
+
+@pytest.mark.skipif(not is_deep_gemm_supported(),
+                    reason="Requires deep_gemm kernels")
+@pytest.mark.parametrize("E", [16, 32])  # number of experts
+@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
+@pytest.mark.parametrize("K", [128, 256])  # hidden dim
+@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
+@pytest.mark.parametrize("topk", [2, 4])
+def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
+                                    monkeypatch):
+    """Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
+
+    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
+
+    device = "cuda"
+    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)
+
+    M = E * T  # total tokens
+    a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    a.clamp_(fp8_info.min, fp8_info.max)
+
+    # random router outputs → top-k indices / weights
+    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
+    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
+    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
+
+    # token number for each expert
+    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
+    max_cnt = int(cnt.max().item())
+    # next power of 2 for max token number
+    max_num_tokens = 1 << (max_cnt - 1).bit_length()
+
+    prep_finalize = BatchedPrepareAndFinalize(
+        max_num_tokens=max_num_tokens,
+        num_local_experts=E,
+        num_dispatchers=1,
+        rank=0,
+    )
+
+    # triton (reference)
+    triton_experts = BatchedTritonExperts(
+        max_num_tokens=max_num_tokens,
+        num_dispatchers=1,
+        use_fp8_w8a8=True,
+        per_act_token_quant=False,
+        block_shape=BLOCK_SIZE,
+    )
+    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
+
+    out_triton = mk_triton(
+        hidden_states=a,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,
+        w1_scale=w1_s,
+        w2_scale=w2_s,
+        global_num_experts=E,
+    )
+
+    # deepgemm
+    deepgemm_experts = BatchedDeepGemmExperts(
+        max_num_tokens=max_num_tokens,
+        num_dispatchers=1,
+        block_shape=BLOCK_SIZE,
+        per_act_token_quant=False,
+    )
+    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
+
+    out_deepgemm = mk_deepgemm(
+        hidden_states=a,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,
+        w1_scale=w1_s,
+        w2_scale=w2_s,
+        global_num_experts=E,
+    )
+
+    diff = calc_diff(out_deepgemm, out_triton)
+    assert diff < 1e-3, f"Output diff too large: {diff}"
diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index b6ea4ee2324c9..b2b78662c9ded 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -20,11 +20,6 @@ from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
 
 BLOCK_SIZE = [128, 128]
 
-requires_deep_gemm = pytest.mark.skipif(
-    not is_deep_gemm_supported(),
-    reason="Requires deep_gemm kernels",
-)
-
 
 def make_block_quant_fp8_weights(
     e: int,
@@ -152,7 +147,8 @@ NUM_EXPERTS = [32]
 @pytest.mark.parametrize("mnk", MNKs)
 @pytest.mark.parametrize("topk", TOPKS)
 @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
-@requires_deep_gemm
+@pytest.mark.skipif(not is_deep_gemm_supported(),
+                    reason="Requires deep_gemm kernels")
 def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
 
     with monkeypatch.context() as m:
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 8ab34e7505ee2..0edfb01cde9d6 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -23,10 +23,10 @@ def is_deep_gemm_supported() -> bool:
     """Return ``True`` if DeepGEMM is supported on the current platform.
     Currently, only Hopper and Blackwell GPUs are supported.
     """
-    supported_arch = current_platform.is_cuda() and (
+    is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
         or current_platform.is_device_capability(100))
-    return has_deep_gemm() and supported_arch
+    return has_deep_gemm() and is_supported_arch
 
 
 @functools.cache
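
Note (not part of the patch): a minimal sketch for running only the new test module locally, assuming pytest and vLLM are installed and a DeepGEMM-capable (Hopper or Blackwell) GPU is available; the script below is an illustrative example, not a vLLM utility.

# Minimal sketch: run only the new batched DeepGEMM test module from a script.
# Assumes pytest and vLLM are installed and deep_gemm kernels are usable;
# otherwise the skipif marker in the test will skip every case.
import sys

import pytest

if __name__ == "__main__":
    # Equivalent to: pytest -q tests/kernels/moe/test_batched_deepgemm.py
    sys.exit(pytest.main(["-q", "tests/kernels/moe/test_batched_deepgemm.py"]))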