diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
index 3f9b32ce5a368..54f2351bf6d9b 100644
--- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -5,6 +5,15 @@ from dataclasses import dataclass, fields
 import pytest
 import torch
 import torch.nn.functional as F
+
+from vllm.utils import has_triton_kernels
+
+if not has_triton_kernels():
+    pytest.skip(
+        "triton_kernels not found, skipping all related tests",
+        allow_module_level=True,
+    )
+
 import triton_kernels.swiglu
 from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
 from triton_kernels.numerics import InFlexData
@@ -65,7 +74,7 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
     dtype_dict = {
         "bf16": torch.bfloat16,
         "fp8_e4m3": torch.float8_e4m3fn,
-        "fp8_e5m2": torch.float8_e5m2
+        "fp8_e5m2": torch.float8_e5m2,
     }
 
     x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
@@ -97,12 +106,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
 
         x_pad = w1_bottom_pad
 
-        w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
-                       mode="constant",
-                       value=0)
-        w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
-                       mode="constant",
-                       value=0)
+        w1_tri = F.pad(
+            w1_tri,
+            (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
+            mode="constant",
+            value=0,
+        )
+        w2_tri = F.pad(
+            w2_tri,
+            (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
+            mode="constant",
+            value=0,
+        )
 
         w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
                             mode="constant",
@@ -127,13 +142,19 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
 
         w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
                                 **w_layout_opts)
-        w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
-                                      w_scale_layout, **w_scale_layout_opts)
+        w1_scale_tri = convert_layout(
+            wrap_torch_tensor(w1_scale_tri),
+            w_scale_layout,
+            **w_scale_layout_opts,
+        )
 
         w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
                                 **w_layout_opts)
-        w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
-                                      w_scale_layout, **w_scale_layout_opts)
+        w2_scale_tri = convert_layout(
+            wrap_torch_tensor(w2_scale_tri),
+            w_scale_layout,
+            **w_scale_layout_opts,
+        )
 
         pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
                               flex_ctx=FlexCtx(rhs_data=InFlexData()))
@@ -149,8 +170,22 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
     w1 = w1.transpose(-1, -2).contiguous()
     w2 = w2.transpose(-1, -2).contiguous()
 
-    return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
-            exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)
+    return (
+        x,
+        w1,
+        w1_bias,
+        w2,
+        w2_bias,
+        exp_data,
+        x_tri,
+        w1_tri,
+        w2_tri,
+        exp_data_tri,
+        w1_bias_tri,
+        w2_bias_tri,
+        pc1,
+        pc2,
+    )
 
 
 @dataclass
@@ -184,13 +219,14 @@ def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
 
 
 def oai_moe_forward(
-        hidden_states: torch.Tensor,  # (M, K)
-        w1: torch.Tensor,  # (E, 2N)
-        w1_bias: torch.Tensor,  # (E, 2N, K)
-        w2: torch.Tensor,  # (E, K, N)
-        w2_bias: torch.Tensor,  # (E, N)
-        gating_output: torch.Tensor,  # (M, E)
-        topk: int):
+    hidden_states: torch.Tensor,  # (M, K)
+    w1: torch.Tensor,  # (E, 2N)
+    w1_bias: torch.Tensor,  # (E, 2N, K)
+    w2: torch.Tensor,  # (E, K, N)
+    w2_bias: torch.Tensor,  # (E, N)
+    gating_output: torch.Tensor,  # (M, E)
+    topk: int,
+):
     # model.py 309:330, assuming gating and norm
     t = hidden_states
     experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
@@ -240,10 +276,22 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     N = ModelConfig.intermediate_size // tp
     topk = ModelConfig.experts_per_token
 
-    x, w1, w1_bias, w2, w2_bias, exp_data, \
-        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
-        w2_bias_tri, pc1, pc2 = init_compute_data(
-            M, K, N, E, a_dtype, w_dtype, num_warps=8)
+    (
+        x,
+        w1,
+        w1_bias,
+        w2,
+        w2_bias,
+        exp_data,
+        x_tri,
+        w1_tri,
+        w2_tri,
+        exp_data_tri,
+        w1_bias_tri,
+        w2_bias_tri,
+        pc1,
+        pc2,
+    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
 
     out_triton_monolithic = triton_kernel_moe_forward(
         hidden_states=x_tri,
@@ -255,33 +303,46 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
         w1_bias=w1_bias_tri,
         w2_bias=w2_bias_tri,
         w1_precision=pc1,
-        w2_precision=pc2)
+        w2_precision=pc2,
+    )
     out_triton_monolithic = out_triton_monolithic[..., :K]
 
-    out_ref = oai_moe_forward(hidden_states=x,
-                              w1=w1,
-                              w1_bias=w1_bias,
-                              w2=w2,
-                              w2_bias=w2_bias,
-                              gating_output=exp_data,
-                              topk=topk)
+    out_ref = oai_moe_forward(
+        hidden_states=x,
+        w1=w1,
+        w1_bias=w1_bias,
+        w2=w2,
+        w2_bias=w2_bias,
+        gating_output=exp_data,
+        topk=topk,
+    )
     assert_close(ref=out_ref,
                  tri=out_triton_monolithic,
                  maxtol=0.025,
                  rmstol=0.005)
 
 
-def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
-                topk: int, renormalize: bool, w1_bias: torch.Tensor,
-                w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
-                w2_precision: PrecisionConfig) -> torch.Tensor:
+def batched_moe(
+    a: torch.Tensor,
+    w1,
+    w2,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    w1_precision: PrecisionConfig,
+    w2_precision: PrecisionConfig,
+) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)
 
     fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(max_num_tokens,
-                                  num_dispatchers=1,
-                                  num_local_experts=w1.shape[0],
-                                  rank=0),
+        BatchedPrepareAndFinalize(
+            max_num_tokens,
+            num_dispatchers=1,
+            num_local_experts=w1.shape[0],
+            rank=0,
+        ),
         BatchedOAITritonExperts(
             None,
             max_num_tokens=max_num_tokens,
@@ -327,30 +388,46 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
     N = ModelConfig.intermediate_size
     topk = ModelConfig.experts_per_token
 
-    x, w1, w1_bias, w2, w2_bias, exp_data, \
-        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
-        w2_bias_tri, pc1, pc2 = init_compute_data(
-            M, K, N, E, a_dtype, w_dtype, num_warps=4)
+    (
+        x,
+        w1,
+        w1_bias,
+        w2,
+        w2_bias,
+        exp_data,
+        x_tri,
+        w1_tri,
+        w2_tri,
+        exp_data_tri,
+        w1_bias_tri,
+        w2_bias_tri,
+        pc1,
+        pc2,
+    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
 
-    out_tri = batched_moe(a=x_tri,
-                          w1=w1_tri,
-                          w2=w2_tri,
-                          gating_output=exp_data_tri,
-                          topk=topk,
-                          renormalize=True,
-                          w1_bias=w1_bias_tri,
-                          w2_bias=w2_bias_tri,
-                          w1_precision=pc1,
-                          w2_precision=pc2)
+    out_tri = batched_moe(
+        a=x_tri,
+        w1=w1_tri,
+        w2=w2_tri,
+        gating_output=exp_data_tri,
+        topk=topk,
+        renormalize=True,
+        w1_bias=w1_bias_tri,
+        w2_bias=w2_bias_tri,
+        w1_precision=pc1,
+        w2_precision=pc2,
+    )
     out_tri = out_tri[..., :K]
 
-    out_ref = oai_moe_forward(hidden_states=x,
-                              w1=w1,
-                              w1_bias=w1_bias,
-                              w2=w2,
-                              w2_bias=w2_bias,
-                              gating_output=exp_data,
-                              topk=topk)
+    out_ref = oai_moe_forward(
+        hidden_states=x,
+        w1=w1,
+        w1_bias=w1_bias,
+        w2=w2,
+        w2_bias=w2_bias,
+        gating_output=exp_data,
+        topk=topk,
+    )
     assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
 
 
@@ -370,6 +447,7 @@ def test_unit_shuffle():
     out = triton_kernels.swiglu.swiglu_torch(
         out,
         alpha=1.702,
-        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))
+        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0),
+    )
 
-    assert_close(ref=out_ref, tri=out)
\ No newline at end of file
+    assert_close(ref=out_ref, tri=out)
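
A note on the guard this patch adds at the top of the file: pytest.skip(reason, allow_module_level=True) is the documented pytest idiom for skipping an entire test module, and placing it before "import triton_kernels.swiglu" is what keeps that import from raising ImportError when the optional package is absent. A minimal, self-contained sketch of the same pattern follows; the "optional_dep" module name is illustrative only and not part of this PR:

import importlib.util

import pytest

# Module-level guard: must execute before any import of the optional package.
if importlib.util.find_spec("optional_dep") is None:
    pytest.skip(
        "optional_dep not installed, skipping module",
        allow_module_level=True,
    )

# Only reached when the dependency is available.
import optional_dep  # noqa: E402


def test_optional_dep_available():
    assert optional_dep is not None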