[Kernel] Use pre-allocated output buffer for triton kernel fused_experts (#29219)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang 2025-11-25 18:21:00 -08:00 committed by GitHub
parent c5ee430328
commit 53d7f1f601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output, topk, sm_first=not renormalize gating_output, topk, sm_first=not renormalize
) )
output = torch.empty_like(hidden_states)
return triton_kernel_fused_experts( return triton_kernel_fused_experts(
None, output,
hidden_states, hidden_states,
w1, w1,
w2, w2,
routing_data, routing_data,
gather_idx, gather_idx,
scatter_idx, scatter_idx,
topk=topk,
activation=activation, activation=activation,
quant_config=quant_config, quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
routing_data, # RoutingData routing_data, # RoutingData
gather_indx, # GatherIndx gather_indx, # GatherIndx
scatter_indx, # ScatterIndx scatter_indx, # ScatterIndx
topk: int,
activation: str = "silu", activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702, swiglu_alpha: float = 1.702,
@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
intermediate_cache: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None, a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4 # Shape check, only check non-mxfp4
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2] assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1] assert w2.shape[-1] == w1.shape[1]
batch_dim = 1
M, K = hidden_states.shape[-2:]
E, _, N = w1.shape E, _, N = w1.shape
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
if intermediate_cache is None:
intermediate_cache = torch.empty(
(batch_dim, M * topk, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache = _resize_cache(
intermediate_cache, (batch_dim, M * topk, N // 2)
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
act = FusedActivation( act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit), (swiglu_alpha, swiglu_limit),
@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
) )
gammas = routing_data.gate_scal if routing_data else None gammas = routing_data.gate_scal if routing_data else None
intermediate_cache1 = matmul_ogs( matmul_ogs(
hidden_states, hidden_states,
w1, w1,
quant_config.w1_bias, quant_config.w1_bias,
@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
precision_config=quant_config.w1_precision, precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None, gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act, fused_activation=act,
y=intermediate_cache,
) )
intermediate_cache3 = matmul_ogs( matmul_ogs(
intermediate_cache1, intermediate_cache.view(M * topk, N // 2),
w2, w2,
quant_config.w2_bias, quant_config.w2_bias,
routing_data, routing_data,
@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
gammas=None if apply_router_weight_on_input else gammas, gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor, y=output_tensor,
) )
return intermediate_cache3 output_tensor = output_tensor.view(M, K)
return output_tensor
def make_routing_data( def make_routing_data(
@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
def moe_problem_size(
    self,
    a1: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
    """
    Extract the MoE problem size from the given tensor arguments:
    - a1: The hidden states, input to the MoE layer.
    - w1: The first set of expert weights.
    - w2: The second set of expert weights.
    - topk_ids: The topk ids.

    Returns the tuple ``(E, M, N, K, topk)`` where ``E`` and ``N`` are the
    first and last dimensions of ``w1``, ``M`` and ``K`` are the two
    dimensions of ``a1``, and ``topk`` is the second dimension of
    ``topk_ids``.

    Note: extracting the problem shape from the weight and activation
    tensors is not obvious. It needs to be done this way specifically
    due to subtle issues with particular kernels, e.g. the int4 kernels
    divide the trailing dimension by two, so it's not "correct" to
    extract N or K from the trailing dimension of w1 or w2. Similarly,
    some kernels transpose the weights, so this needs to be kept in mind.

    Note: This implementation covers most cases. However, if experts
    require a specialized implementation, like MarlinExperts, they are free
    to override this function.
    """
    # Check every rank up front so that a malformed tensor trips the
    # intended assertion instead of an IndexError from a later size() call.
    assert w1.dim() == 3 and w2.dim() == 3
    assert a1.dim() == 2
    assert topk_ids.dim() == 2

    E, _, N = w1.size()
    M, K = a1.size()

    # One row of top-k expert ids is required per input token.
    assert topk_ids.size(0) == M, f"{topk_ids.size(0)} != {M}"
    topk = topk_ids.size(1)

    return E, M, N, K, topk
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel. # Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
workspace1 = (M, K) workspace1 = (0, 0)
workspace2 = (0, 0) workspace2 = (M * topk, N // 2)
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids, topk_weights, local_num_experts topk_ids, topk_weights, local_num_experts
) )
experts_output = triton_kernel_fused_experts( topk = topk_ids.size(1)
None, triton_kernel_fused_experts(
output,
hidden_states, hidden_states,
w1, w1,
w2, w2,
routing_data, routing_data,
gather_indx, gather_indx,
scatter_indx, scatter_indx,
topk=topk,
activation=activation, activation=activation,
quant_config=self.quant_config, quant_config=self.quant_config,
apply_router_weight_on_input=False, apply_router_weight_on_input=False,
global_num_experts=local_num_experts, global_num_experts=local_num_experts,
expert_map=None, # applied already expert_map=None, # applied already
intermediate_cache=workspace2,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
) )
output.copy_(experts_output, non_blocking=True)