# vllm/tests/kernels/moe/test_batched_deepgemm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported

from .test_deepgemm import make_block_quant_fp8_weights
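
# Tile shape for fp8 block-wise quantization; DeepGEMM's fp8 kernels operate
# on 128x128 weight blocks.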
BLOCK_SIZE = [128, 128]


@pytest.mark.skipif(not is_deep_gemm_supported(),
                    reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
                                    monkeypatch):
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
device = "cuda"
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)
M = E * T # total tokens
a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
fp8_info = torch.finfo(torch.float8_e4m3fn)
a.clamp_(fp8_info.min, fp8_info.max)

    # random router outputs → top-k indices / weights
    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # number of tokens routed to each expert
    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
    max_cnt = int(cnt.max().item())
    # pad the busiest expert's token count up to the next power of two
    max_num_tokens = 1 << (max_cnt - 1).bit_length()
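    # e.g. max_cnt = 300: (300 - 1).bit_length() == 9, so max_num_tokens == 512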

    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )
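    # One prepare/finalize instance is shared below so both modular kernels
    # see identical per-expert token batching.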

    # fp8 W8A8 quantization config shared by the reference and test kernels
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        per_act_token_quant=False,
        block_shape=BLOCK_SIZE,
    )

    # Triton experts (reference implementation)
    triton_experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
    )
    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
    out_triton = mk_triton(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        global_num_experts=E,
    )

    # DeepGEMM experts (implementation under test)
    deepgemm_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        quant_config=quant_config,
    )
    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
    out_deepgemm = mk_deepgemm(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        global_num_experts=E,
    )
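
    # calc_diff reports a cosine-like relative difference (as in DeepGEMM's
    # test utilities); requiring < 1e-3 leaves headroom for fp8 rounding noise.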
    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"