Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:04:58 +08:00)
[Bugfix] Fix CI moe kernel failure (#22556)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Parent: 7920e9b1c5
Commit: 0edc0cd52b
@@ -5,6 +5,15 @@ from dataclasses import dataclass, fields
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
pytest.skip(
|
||||
"triton_kernels not found, skipping all related tests",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
@@ -65,7 +74,7 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
dtype_dict = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
|
||||
@@ -97,12 +106,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
|
||||
x_pad = w1_bottom_pad
|
||||
|
||||
w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
w1_tri = F.pad(
|
||||
w1_tri,
|
||||
(0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
w2_tri = F.pad(
|
||||
w2_tri,
|
||||
(0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
|
||||
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
|
||||
mode="constant",
|
||||
@@ -127,13 +142,19 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
|
||||
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
|
||||
**w_layout_opts)
|
||||
w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout, **w_scale_layout_opts)
|
||||
w1_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
|
||||
**w_layout_opts)
|
||||
w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout, **w_scale_layout_opts)
|
||||
w2_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
|
||||
flex_ctx=FlexCtx(rhs_data=InFlexData()))
|
||||
@@ -149,8 +170,22 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
w1 = w1.transpose(-1, -2).contiguous()
|
||||
w2 = w2.transpose(-1, -2).contiguous()
|
||||
|
||||
return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
|
||||
exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)
|
||||
return (
|
||||
x,
|
||||
w1,
|
||||
w1_bias,
|
||||
w2,
|
||||
w2_bias,
|
||||
exp_data,
|
||||
x_tri,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
exp_data_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
pc1,
|
||||
pc2,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -184,13 +219,14 @@ def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
|
||||
|
||||
|
||||
def oai_moe_forward(
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, 2N)
|
||||
w1_bias: torch.Tensor, # (E, 2N, K)
|
||||
w2: torch.Tensor, # (E, K, N)
|
||||
w2_bias: torch.Tensor, # (E, N)
|
||||
gating_output: torch.Tensor, # (M, E)
|
||||
topk: int):
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, 2N)
|
||||
w1_bias: torch.Tensor, # (E, 2N, K)
|
||||
w2: torch.Tensor, # (E, K, N)
|
||||
w2_bias: torch.Tensor, # (E, N)
|
||||
gating_output: torch.Tensor, # (M, E)
|
||||
topk: int,
|
||||
):
|
||||
# model.py 309:330, assuming gating and norm
|
||||
t = hidden_states
|
||||
experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
|
||||
@@ -240,10 +276,22 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
N = ModelConfig.intermediate_size // tp
|
||||
topk = ModelConfig.experts_per_token
|
||||
|
||||
x, w1, w1_bias, w2, w2_bias, exp_data, \
|
||||
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
|
||||
w2_bias_tri, pc1, pc2 = init_compute_data(
|
||||
M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
(
|
||||
x,
|
||||
w1,
|
||||
w1_bias,
|
||||
w2,
|
||||
w2_bias,
|
||||
exp_data,
|
||||
x_tri,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
exp_data_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
pc1,
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
@@ -255,33 +303,46 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2)
|
||||
w2_precision=pc2,
|
||||
)
|
||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||
|
||||
out_ref = oai_moe_forward(hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk)
|
||||
out_ref = oai_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk,
|
||||
)
|
||||
assert_close(ref=out_ref,
|
||||
tri=out_triton_monolithic,
|
||||
maxtol=0.025,
|
||||
rmstol=0.005)
|
||||
|
||||
|
||||
def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
|
||||
topk: int, renormalize: bool, w1_bias: torch.Tensor,
|
||||
w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
|
||||
w2_precision: PrecisionConfig) -> torch.Tensor:
|
||||
def batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1,
|
||||
w2,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
w1_bias: torch.Tensor,
|
||||
w2_bias: torch.Tensor,
|
||||
w1_precision: PrecisionConfig,
|
||||
w2_precision: PrecisionConfig,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
num_local_experts=w1.shape[0],
|
||||
rank=0),
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
num_local_experts=w1.shape[0],
|
||||
rank=0,
|
||||
),
|
||||
BatchedOAITritonExperts(
|
||||
None,
|
||||
max_num_tokens=max_num_tokens,
|
||||
@@ -327,30 +388,46 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
|
||||
N = ModelConfig.intermediate_size
|
||||
topk = ModelConfig.experts_per_token
|
||||
|
||||
x, w1, w1_bias, w2, w2_bias, exp_data, \
|
||||
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
|
||||
w2_bias_tri, pc1, pc2 = init_compute_data(
|
||||
M, K, N, E, a_dtype, w_dtype, num_warps=4)
|
||||
(
|
||||
x,
|
||||
w1,
|
||||
w1_bias,
|
||||
w2,
|
||||
w2_bias,
|
||||
exp_data,
|
||||
x_tri,
|
||||
w1_tri,
|
||||
w2_tri,
|
||||
exp_data_tri,
|
||||
w1_bias_tri,
|
||||
w2_bias_tri,
|
||||
pc1,
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
|
||||
|
||||
out_tri = batched_moe(a=x_tri,
|
||||
w1=w1_tri,
|
||||
w2=w2_tri,
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2)
|
||||
out_tri = batched_moe(
|
||||
a=x_tri,
|
||||
w1=w1_tri,
|
||||
w2=w2_tri,
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
)
|
||||
out_tri = out_tri[..., :K]
|
||||
|
||||
out_ref = oai_moe_forward(hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk)
|
||||
out_ref = oai_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk,
|
||||
)
|
||||
assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
|
||||
|
||||
|
||||
@@ -370,6 +447,7 @@ def test_unit_shuffle():
|
||||
out = triton_kernels.swiglu.swiglu_torch(
|
||||
out,
|
||||
alpha=1.702,
|
||||
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))
|
||||
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0),
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user