# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test modular OAI Triton MoE
"""

import pytest
import torch

from vllm.utils.import_utils import has_triton_kernels

if not has_triton_kernels():
    pytest.skip(
        "triton_kernels not found, skipping all related tests",
        allow_module_level=True,
    )

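# These imports sit below the module-level skip on purpose: they are only safe
# to execute when the optional triton_kernels package is installed.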
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    OAITritonExperts,
    UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.platforms import current_platform

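# Problem sizes to sweep: m is the number of tokens, k the hidden size and
# n the MoE intermediate size (see make_weights and the test below).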
MNK = [
    (1, 512, 384),
    (1, 2880, 2880),
    (2, 512, 384),
    (2, 2880, 2880),
    (16, 2880, 2880),
]


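# Inverse of vllm's shuffle_weight: de-interleave the last dimension back into
# contiguous [first_half, second_half] order.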
def unshuffle_weight(w: torch.Tensor):
    first = w[..., ::2]
    second = w[..., 1::2]
    return torch.concat((first, second), dim=-1)


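# Build matching reference and Triton weight sets. The Triton copies are
# shuffled, quantized to mxfp4 and converted to the kernel's value/scale
# layouts; the reference copies are the dequantized values, so both paths see
# the same quantization error.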
def make_weights(dtype, k, n, e):
    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device="cuda")
    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device="cuda")

    w2 = torch.randn((e, n, k), dtype=dtype, device="cuda")
    w2_bias = torch.randn((e, k), dtype=dtype, device="cuda")

    w1_tri = w1.clone()
    w2_tri = w2.clone()

    w1_bias_tri = w1_bias.clone()
    w2_bias_tri = w2_bias.clone()
    w1_bias_tri = w1_bias_tri.to(torch.float32)
    w2_bias_tri = w2_bias_tri.to(torch.float32)

    # shuffle the w1 (gate/up) weights and bias for the Triton path
    w1_tri = shuffle_weight(w1_tri)
    w1_bias_tri = shuffle_weight(w1_bias_tri)

    # quantize the Triton weights to mxfp4; the dequantized values become the
    # reference weights
    w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
    w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1)
    w1 = unshuffle_weight(w1)

    w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)

    # convert the quantized weights and scales to the kernel's default layouts
    num_warps = 8
    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    w_scale_layout, w_scale_layout_opts = (
        layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
    )

    w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts)
    w1_scale_tri = convert_layout(
        wrap_torch_tensor(w1_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts)
    w2_scale_tri = convert_layout(
        wrap_torch_tensor(w2_scale_tri),
        w_scale_layout,
        **w_scale_layout_opts,
    )

    w1_precision_config = PrecisionConfig(
        weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )
    w2_precision_config = PrecisionConfig(
        weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )

    return (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    )


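# Clamped SwiGLU reference: the gate is clamped from above, the linear term from
# both sides, and a +1 bias is added to the linear term. This mirrors the
# "swigluoai" activation requested from the Triton path below.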
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    # Note we add an extra bias of 1 to the linear layer
    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    if limit is not None:
        x_glu = x_glu.clamp(max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    if limit is not None:
        x_linear = x_linear.clamp(min=-limit, max=limit)
    return out_glu * (x_linear + 1)


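# Pure-torch reference MoE: gather each token's top-k expert weights, apply the
# two projections (with biases) and the clamped SwiGLU, then combine the expert
# outputs with the router weights.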
def torch_moe_impl(
    hidden_states: torch.Tensor,  # (M, K)
    w1: torch.Tensor,  # (E, K, 2N)
    w2: torch.Tensor,  # (E, N, K)
    w1_bias: torch.Tensor,  # (E, 2N)
    w2_bias: torch.Tensor,  # (E, K)
    topk_weights: torch.Tensor,  # (M, topk)
    topk_ids: torch.Tensor,  # (M, topk)
):
    # (M, topk, K, 2N) x (M, K) -> (M, topk, 2N)
    w1 = w1[topk_ids, ...]
    w1_bias = w1_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bk->bec", w1, hidden_states) + w1_bias
    hidden_states = swiglu(hidden_states, limit=7)

    # (M, topk, N, K) x (M, topk, N) -> (M, topk, K)
    w2 = w2[topk_ids, ...]
    w2_bias = w2_bias[topk_ids, ...]
    hidden_states = torch.einsum("bekc,bek->bec", w2, hidden_states) + w2_bias

    # Weighted sum of experts
    hidden_states = torch.einsum("bec,be->bc", hidden_states, topk_weights)
    return hidden_states


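# Run the same MoE through the modular kernel path: MoEPrepareAndFinalizeNoEP
# handles prepare/finalize (no expert parallelism) while OAITritonExperts /
# UnfusedOAITritonExperts execute the mxfp4 Triton kernels.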
def oai_triton_moe_impl(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: "PrecisionConfig",
    w2_scale: "PrecisionConfig",
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    num_experts: int,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    unfused: bool = False,
) -> torch.Tensor:
    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )

    if unfused:
        fused_experts = UnfusedOAITritonExperts(quant_config)
    else:
        fused_experts = OAITritonExperts(quant_config)

    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)

    return mk.forward(
        hidden_states=x,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation="swigluoai",
        global_num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )


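# End-to-end check: the fused and unfused Triton mxfp4 paths should match the
# torch reference built from the dequantized weights.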
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("m,n,k", MNK)
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [4])
@pytest.mark.parametrize("unfused", [True, False])
def test_oai_triton_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    unfused: bool,
):
    current_platform.seed_everything(0)
    (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    ) = make_weights(dtype, k, n, num_experts)

    x = torch.randn((m, k), dtype=dtype, device="cuda")
    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    with set_current_vllm_config(VllmConfig()):
        out_ref = torch_moe_impl(x, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids)

        out = oai_triton_moe_impl(
            x,
            w1_tri,
            w2_tri,
            w1_precision_config,
            w2_precision_config,
            w1_bias_tri,
            w2_bias_tri,
            num_experts,
            topk_weights,
            topk_ids,
            unfused,
        )

    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)