From 53d7f1f601a12b8fa58aa0bffb7fc27d63d1eb5e Mon Sep 17 00:00:00 2001
From: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date: Tue, 25 Nov 2025 18:21:00 -0800
Subject: [PATCH] [Kernel] Use pre-allocated output buffer for triton kernel
 fused_experts (#29219)

Signed-off-by: Xin Yang
---
 .../fused_moe/gpt_oss_triton_kernels_moe.py   | 84 ++++++++++++++++---
 1 file changed, 73 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index badedfc54c382..128507639fdfd 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
         gating_output, topk, sm_first=not renormalize
     )
 
+    output = torch.empty_like(hidden_states)
+
     return triton_kernel_fused_experts(
-        None,
+        output,
         hidden_states,
         w1,
         w2,
         routing_data,
         gather_idx,
         scatter_idx,
+        topk=topk,
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
     routing_data,  # RoutingData
     gather_indx,  # GatherIndx
     scatter_indx,  # ScatterIndx
+    topk: int,
     activation: str = "silu",
     quant_config: FusedMoEQuantConfig | None = None,
     swiglu_alpha: float = 1.702,
@@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    intermediate_cache: torch.Tensor | None = None,
     a1q_scale: torch.Tensor | None = None,
 ) -> torch.Tensor:
     if quant_config is None:
@@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
     assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
 
     # Shape check, only check non-mxfp4
+    assert hidden_states.ndim == 2
     assert hidden_states.shape[-1] == w1.shape[-2]
     assert w2.shape[-1] == w1.shape[1]
 
+    batch_dim = 1
+    M, K = hidden_states.shape[-2:]
     E, _, N = w1.shape
 
     if global_num_experts == -1:
         global_num_experts = E
 
+    if intermediate_cache is None:
+        intermediate_cache = torch.empty(
+            (batch_dim, M * topk, N // 2),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    # Add batch_dim to output buffer because matmul_ogs expects 3D output
+    intermediate_cache = _resize_cache(
+        intermediate_cache, (batch_dim, M * topk, N // 2)
+    )
+    output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
+
     act = FusedActivation(
         FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
         (swiglu_alpha, swiglu_limit),
@@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
     )
     gammas = routing_data.gate_scal if routing_data else None
 
-    intermediate_cache1 = matmul_ogs(
+    matmul_ogs(
         hidden_states,
         w1,
         quant_config.w1_bias,
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
         precision_config=quant_config.w1_precision,
         gammas=gammas if apply_router_weight_on_input else None,
         fused_activation=act,
+        y=intermediate_cache,
     )
 
-    intermediate_cache3 = matmul_ogs(
-        intermediate_cache1,
+    matmul_ogs(
+        intermediate_cache.view(M * topk, N // 2),
         w2,
         quant_config.w2_bias,
         routing_data,
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
         gammas=None if apply_router_weight_on_input else gammas,
         y=output_tensor,
     )
-    return intermediate_cache3
+    output_tensor = output_tensor.view(M, K)
+    return output_tensor
 
 
 def make_routing_data(
@@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def supports_expert_map(self) -> bool:
         return True
 
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        """
+        Extract the MoE problem size from the given tensor arguments:
+        - a: The hidden states, input to the MoE layer.
+        - w1: The first set of expert weights.
+        - w2: The second set of expert weights.
+        - topk_ids: The topk ids.
+
+        Note: extracting the problem shape from the weight and activation
+        tensors is not obvious. It needs to be done this way specifically
+        due to subtle issues with particular kernels, e.g. the int4 kernels
+        divide the trailing dimension by two, so it's not "correct" to
+        extract N or K from the trailing dimension of w1 or w2. Similarly,
+        some kernels transpose the weights, so this needs to be kept in mind.
+
+        Note: This implementation covers most cases. However, if experts
+        require a specialized implementation, like MarlinExperts, they are
+        free to override this function.
+        """
+        assert w1.dim() == 3 and w2.dim() == 3
+        E, _, N = w1.size()
+        K = a1.size(-1)
+
+        assert a1.dim() == 2
+        assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
+        M = a1.size(0)
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Weight application and reduction happens in the fused_experts kernel.
         return TopKWeightAndReduceNoOP()
@@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         # workspace are allocated inside the kernel
-        workspace1 = (M, K)
-        workspace2 = (0, 0)
+        workspace1 = (0, 0)
+        workspace2 = (M * topk, N // 2)
         output = (M, K)
         return (workspace1, workspace2, output)
 
@@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
             topk_ids, topk_weights, local_num_experts
         )
 
-        experts_output = triton_kernel_fused_experts(
-            None,
+        topk = topk_ids.size(1)
+        triton_kernel_fused_experts(
+            output,
             hidden_states,
             w1,
             w2,
             routing_data,
             gather_indx,
             scatter_indx,
+            topk=topk,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
             global_num_experts=local_num_experts,
             expert_map=None,  # applied already
+            intermediate_cache=workspace2,
             a1q_scale=a1q_scale,
         )
-
-        output.copy_(experts_output, non_blocking=True)
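
A minimal sketch of the buffer-reuse pattern this patch adopts, for readers new
to the modular-kernel workspace contract: the caller pre-allocates both the
final output and the (M * topk, N // 2) intermediate buffer, and each matmul
writes into its destination instead of returning a freshly allocated tensor.
This is what removes the per-call allocation of intermediate_cache1 /
intermediate_cache3 and the trailing output.copy_ in the old code. matmul_ogs,
the gather/scatter indices, and the swiglu activation are replaced below by
plain dense torch ops; all sizes and stand-in names are illustrative only, not
the real kernel API.

import torch

# Illustrative sizes: M tokens, hidden size K, 2x intermediate size N,
# topk experts per token (mirroring the shapes used in the patch).
M, K, N, topk = 4, 8, 16, 2

hidden_states = torch.randn(M, K)
w1 = torch.randn(K, N // 2)  # stand-in for the first grouped matmul + activation
w2 = torch.randn(N // 2, K)  # stand-in for the second grouped matmul

# Caller-side pre-allocation, as triton_kernel_moe_forward and
# OAITritonExperts.apply now do (output buffer and workspace2).
output = torch.empty(M, K)
intermediate_cache = torch.empty(M * topk, N // 2)

# Each matmul writes into its pre-allocated destination via out=,
# mirroring matmul_ogs(..., y=intermediate_cache) and matmul_ogs(..., y=output_tensor).
gathered = hidden_states.repeat_interleave(topk, dim=0)        # gather_indx stand-in
torch.matmul(gathered, w1, out=intermediate_cache)
reduced = intermediate_cache.view(M, topk, N // 2).sum(dim=1)  # scatter_indx + reduce stand-in
torch.matmul(reduced, w2, out=output)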
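
A quick shape walk-through of the new moe_problem_size() helper (all sizes
invented for illustration): with w1 of shape (E, K, N) = (32, 8, 16), a1 of
shape (M, K) = (4, 8), and topk_ids of shape (M, topk) = (4, 2), the method
returns (E, M, N, K, topk) = (32, 4, 16, 8, 2).

import torch

E, M, N, K, topk = 32, 4, 16, 8, 2
a1 = torch.randn(M, K)                             # hidden states
w1 = torch.randn(E, K, N)                          # first expert weights
w2 = torch.randn(E, N // 2, K)                     # only w2.dim() == 3 is asserted
topk_ids = torch.zeros(M, topk, dtype=torch.long)  # topk ids

# Mirrors the method body: E and N from w1, K and M from a1, topk from topk_ids.
assert (w1.size(0), a1.size(0), w1.size(2), a1.size(-1), topk_ids.size(1)) == (
    E, M, N, K, topk
)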