diff --git a/requirements/cuda.txt b/requirements/cuda.txt
index dd45eb832a96a..7c5bc457d45b0 100644
--- a/requirements/cuda.txt
+++ b/requirements/cuda.txt
@@ -13,3 +13,5 @@ torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytor
 # xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.4.1
+# Triton Kernels are needed for mxfp4 fused moe. (Should be updated alongside torch)
+triton_kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels
diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
index b8f213d33c963..d4a79a7eff75d 100644
--- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -23,15 +23,9 @@ from triton_kernels.tensor_details import layout
 from triton_kernels.testing import assert_close
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize,
-)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
-    BatchedOAITritonExperts,
     triton_kernel_moe_forward,
 )
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.utils import shuffle_weight
 from vllm.utils import round_up
@@ -302,8 +296,8 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     quant_config = FusedMoEQuantConfig.make(
         w1_bias=w1_bias_tri,
         w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
+        w1_scale=pc1,
+        w2_scale=pc2,
     )
 
     out_triton_monolithic = triton_kernel_moe_forward(
@@ -329,115 +323,6 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
 
 
-def batched_moe(
-    a: torch.Tensor,
-    w1,
-    w2,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    w1_bias: torch.Tensor,
-    w2_bias: torch.Tensor,
-    w1_precision: PrecisionConfig,
-    w2_precision: PrecisionConfig,
-) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64)
-
-    quant_config = FusedMoEQuantConfig.make(
-        w1_precision=w1_precision,
-        w2_precision=w2_precision,
-        w1_bias=w1_bias,
-        w2_bias=w2_bias,
-    )
-
-    fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(
-            max_num_tokens,
-            num_dispatchers=1,
-            num_local_experts=w1.shape[0],
-            rank=0,
-        ),
-        BatchedOAITritonExperts(
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=1,
-            quant_config=quant_config,
-        ),
-    )
-
-    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
-
-    return fused_experts(
-        a,
-        w1,
-        w2,
-        topk_weight,
-        topk_ids,
-    )
-
-
-@pytest.mark.parametrize(
-    ", ".join(f.name for f in fields(Case)),
-    [
-        tuple(getattr(case, f.name) for f in fields(Case))
-        for case in [
-            # Case(a_dtype="bf16", w_dtype="bf16"),
-            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
-            Case(a_dtype="bf16", w_dtype="mx4")
-        ]
-    ],
-)
-@pytest.mark.parametrize("num_token", [64])
-@pytest.mark.parametrize("ep", [1, 2, 4, 8])
-def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
-    M = num_token
-    E = ModelConfig.num_experts // ep
-    K = ModelConfig.hidden_size
-    N = ModelConfig.intermediate_size
-    topk = ModelConfig.experts_per_token
-
-    (
-        x,
-        w1,
-        w1_bias,
-        w2,
-        w2_bias,
-        exp_data,
-        x_tri,
-        w1_tri,
-        w2_tri,
-        exp_data_tri,
-        w1_bias_tri,
-        w2_bias_tri,
-        pc1,
-        pc2,
-    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
-
-    out_tri = batched_moe(
-        a=x_tri,
-        w1=w1_tri,
-        w2=w2_tri,
-        gating_output=exp_data_tri,
-        topk=topk,
-        renormalize=True,
-        w1_bias=w1_bias_tri,
-        w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
-    )
-    out_tri = out_tri[..., :K]
-
-    out_ref = oai_moe_forward(
-        hidden_states=x,
-        w1=w1,
-        w1_bias=w1_bias,
-        w2=w2,
-        w2_bias=w2_bias,
-        gating_output=exp_data,
-        topk=topk,
-    )
-    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
-
-
 def test_unit_shuffle():
     N = ModelConfig.intermediate_size
     K = ModelConfig.hidden_size