vllm/tests/kernels/moe/test_deepep_deepgemm_moe.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
"""

import dataclasses

import pytest
import torch
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm

from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights

if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
        DeepEPHTPrepareAndFinalize,
    )
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
        DeepEPLLPrepareAndFinalize,
    )

    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

if has_deep_gemm():
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts

requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(),
    reason="Requires deep_ep kernels",
)

requires_deep_gemm = pytest.mark.skipif(
    not is_deep_gemm_supported(),
    reason="Requires deep_gemm kernels",
)

P = ParamSpec("P")


def next_power_of_2(x):
    """Return the smallest power of two that is >= x."""
    import math

    if x == 0:
        return 1
    return 2 ** math.ceil(math.log2(x))
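
# Sanity sketch for the helper above (illustrative values, not exercised by
# the test suite):
#   next_power_of_2(0)  -> 1
#   next_power_of_2(64) -> 64
#   next_power_of_2(65) -> 128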


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1q, w2q, w1_scale, w2_scale.
    """
    (_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
    )
    return w1q, w2q, w1_scale, w2_scale


@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
    per_act_token_quant: bool
    block_size: list[int]
    # Configs for testing low-latency kernels.
    low_latency: bool
    use_fp8_dispatch: bool | None = False


@dataclasses.dataclass
class TestTensors:
    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: torch.Tensor | None
    topk: torch.Tensor
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        dtype = torch.bfloat16
        topk, m, k = (config.topk, config.m, config.k)

        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        rank_tokens = (
            torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
        )
        rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
        rank_token_scales = None

        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(m, topk),
            device=torch.cuda.current_device(),
        ).to(dtype=torch.int64)

        topk_weights = torch.randn(
            topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
        )

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=rank_token_scales,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
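

# Shape sketch for TestTensors.make (illustrative values): with m=8, k=128,
# topk=2 and num_experts=32, rank_tokens is an (8, 128) bfloat16 tensor,
# topk ids are (8, 2) int64 and topk_weights are (8, 2) float32, all on the
# current CUDA device.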


def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=DeepEPLLArgs(
            max_tokens_per_rank=max_tokens_per_rank,
            hidden_size=hidden_size,
            num_experts=test_config.num_experts,
            use_fp8_dispatch=test_config.use_fp8_dispatch,
        ),
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    fused_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
    )
    mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
    return mk
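

# Note on the dispatcher count above: num_dispatchers = world_size // dp_size.
# With the (world_size=2, dp_size=1) parametrization used by the tests below,
# each of the two EP ranks acts as its own dispatcher.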


def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    fused_experts = DeepGemmExperts(quant_config)
    mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
    return mk


def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    q_dtype = torch.float8_e4m3fn
    test_config = test_tensors.config

    mk: FusedMoEModularKernel
    # Make modular kernel
    if test_config.low_latency:
        max_tokens_per_rank = max(
            64, next_power_of_2(test_tensors.rank_tokens.size(0))
        )
        hidden_size = test_tensors.rank_tokens.size(-1)

        mk = make_ll_modular_kernel(
            pg=pg,
            pgi=pgi,
            max_tokens_per_rank=max_tokens_per_rank,
            dp_size=dp_size,
            hidden_size=hidden_size,
            q_dtype=q_dtype,
            test_config=test_config,
            quant_config=quant_config,
        )
    else:
        mk = make_ht_modular_kernel(
            pg,
            pgi,
            dp_size,
            num_local_experts,
            q_dtype,
            test_config,
            quant_config=quant_config,
        )

    return mk


def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    test_config = test_tensors.config
    num_experts = test_config.num_experts
    num_local_experts = w1.size(0)

    def build_expert_map():
        # Map global expert ids to local expert ids; experts owned by other
        # ranks stay -1.
        num_local_experts = w1.size(0)
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        s = pgi.rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
        return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
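
    # Illustrative example (values assumed, smaller than the real configs):
    # with num_experts=8, world_size=2 and 4 local experts per rank, rank 1
    # builds
    #   expert_map = [-1, -1, -1, -1, 0, 1, 2, 3]
    # i.e. global experts 4..7 map to local slots 0..3.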

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
        block_shape=test_config.block_size,
    )

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=num_local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    out = mk.forward(
        hidden_states=test_tensors.rank_tokens,
        w1=w1,
        w2=w2,
        topk_weights=test_tensors.topk_weights,
        topk_ids=test_tensors.topk,
        inplace=False,
        activation="silu",
        global_num_experts=num_experts,
        expert_map=build_expert_map(),
        apply_router_weight_on_input=False,
    )
    return out


def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor,
    block_shape: list[int],
):
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
    )

    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
        # Make sure this is set to False so we
        # don't end up comparing the same implementation.
        allow_deep_gemm=False,
    )


def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    current_platform.seed_everything(pgi.rank)

    w1 = w1.to(device=torch.cuda.current_device())
    w2 = w2.to(device=torch.cuda.current_device())
    w1_scale = w1_scale.to(device=torch.cuda.current_device())
    w2_scale = w2_scale.to(device=torch.cuda.current_device())

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)

    block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
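    # Illustrative shapes (assumed from make_test_weights): if w1 is
    # (E, 2N, K) and w1_scale is (E, 2N/128, K/128), the division above
    # recovers block_shape == [128, 128].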

    with set_current_vllm_config(VllmConfig()):
        # Reference
        triton_moe = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Slice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
        e_start = num_local_experts * pgi.rank
        e_end = e_start + num_local_experts
        w1_ep = w1[e_start:e_end]
        w2_ep = w2[e_start:e_end]
        w1_scale_ep = w1_scale[e_start:e_end]
        w2_scale_ep = w2_scale[e_start:e_end]

        deepep_moe = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1_ep,
            w2_ep,
            w1_scale_ep,
            w2_scale_ep,
        )

        torch.testing.assert_close(
            triton_moe,
            deepep_moe,
            atol=6e-2,
            rtol=6e-2,
        )


MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ht_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
import deep_gemm
m, n, k = mnk
current_platform.seed_everything(7)
if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
world_size, dp_size = world_dp_size
config = TestConfig(
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None,
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size
)
parallel_launch(
world_size,
_test_deepep_deepgemm_moe,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
)


MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]

# TODO: enable testing with USE_FP8_DISPATCH=True once those paths are fixed.
USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
use_fp8_dispatch: bool,
block_size: list[int],
world_dp_size: tuple[int, int],
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
m, n, k = mnk
current_platform.seed_everything(7)
if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
world_size, dp_size = world_dp_size
config = TestConfig(
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=True,
use_fp8_dispatch=use_fp8_dispatch,
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size
)
parallel_launch(
world_size,
_test_deepep_deepgemm_moe,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
)
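

# Running these tests locally (a sketch; assumes 2 GPUs plus the deep_ep and
# deep_gemm packages installed):
#   pytest tests/kernels/moe/test_deepep_deepgemm_moe.py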