[Test] Add Unit Test for Batched DeepGEMM (#21559)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 8d524ce79f
commit 6e8d8c4afb
tests/kernels/moe/test_batched_deepgemm.py  (new file, 103 lines)

@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported

from .test_deepgemm import make_block_quant_fp8_weights

BLOCK_SIZE = [128, 128]


@pytest.mark.skipif(not is_deep_gemm_supported(),
                    reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
                                    monkeypatch):
    """Compare BatchedDeepGemmExperts to BatchedTritonExperts."""

    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

    device = "cuda"
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)

    M = E * T  # total tokens
    a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    a.clamp_(fp8_info.min, fp8_info.max)

    # random router outputs → top-k indices / weights
    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # token number for each expert
    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
    max_cnt = int(cnt.max().item())
    # next power of 2 for max token number
    max_num_tokens = 1 << (max_cnt - 1).bit_length()

    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )

    # triton (reference)
    triton_experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        use_fp8_w8a8=True,
        per_act_token_quant=False,
        block_shape=BLOCK_SIZE,
    )
    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)

    out_triton = mk_triton(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        w1_scale=w1_s,
        w2_scale=w2_s,
        global_num_experts=E,
    )

    # deepgemm
    deepgemm_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        block_shape=BLOCK_SIZE,
        per_act_token_quant=False,
    )
    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)

    out_deepgemm = mk_deepgemm(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        w1_scale=w1_s,
        w2_scale=w2_s,
        global_num_experts=E,
    )

    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"
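Note on the buffer sizing above: the test pads each per-expert batch to the next power of two of the largest per-expert token count via `1 << (max_cnt - 1).bit_length()`. A small standalone sketch of that rule (the counts below are made up for illustration and are not part of the commit):

# Illustration only: next-power-of-two padding as used for max_num_tokens above.
# The test derives max_cnt from torch.bincount over topk_ids; these values are hypothetical.
for max_cnt in (1, 3, 200, 256, 300):
    max_num_tokens = 1 << (max_cnt - 1).bit_length()
    print(max_cnt, "->", max_num_tokens)  # 1 -> 1, 3 -> 4, 200 -> 256, 256 -> 256, 300 -> 512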
tests/kernels/moe/test_deepgemm.py

@@ -20,11 +20,6 @@ from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
 
 BLOCK_SIZE = [128, 128]
 
-requires_deep_gemm = pytest.mark.skipif(
-    not is_deep_gemm_supported(),
-    reason="Requires deep_gemm kernels",
-)
-
 
 def make_block_quant_fp8_weights(
     e: int,
@@ -152,7 +147,8 @@ NUM_EXPERTS = [32]
 @pytest.mark.parametrize("mnk", MNKs)
 @pytest.mark.parametrize("topk", TOPKS)
 @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
-@requires_deep_gemm
+@pytest.mark.skipif(not is_deep_gemm_supported(),
+                    reason="Requires deep_gemm kernels")
 def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
 
     with monkeypatch.context() as m:
vllm/utils/deep_gemm.py

@@ -23,10 +23,10 @@ def is_deep_gemm_supported() -> bool:
     """Return ``True`` if DeepGEMM is supported on the current platform.
     Currently, only Hopper and Blackwell GPUs are supported.
     """
-    supported_arch = current_platform.is_cuda() and (
+    is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
         or current_platform.is_device_capability(100))
-    return has_deep_gemm() and supported_arch
+    return has_deep_gemm() and is_supported_arch
 
 
 @functools.cache
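A minimal sketch of running the new test locally, assuming a CUDA Hopper or Blackwell GPU with DeepGEMM installed; this invocation is not part of the commit, and on unsupported hardware the skipif marker simply skips the test:

# Hypothetical local run of the new unit test from the repository root.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["tests/kernels/moe/test_batched_deepgemm.py", "-v"]))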