[Kernel] Use pre-allocated output buffer for triton kernel fused_experts (#29219)

Signed-off-by: Xin Yang <xyangx@amazon.com>
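In outline, the kernel previously allocated its own result tensor; after this change the caller pre-allocates the output and the kernel writes into it, avoiding an extra allocation and copy per call. A minimal sketch of the new calling convention (illustrative shapes; the call itself is left commented out because it needs real expert weights and routing data):

import torch

M, K, topk = 128, 4096, 4                      # tokens, hidden size, experts per token (illustrative)
hidden_states = torch.randn(M, K, dtype=torch.bfloat16)

# Caller-owned output buffer; the kernel fills it in place via matmul_ogs(..., y=output).
output = torch.empty_like(hidden_states)
# triton_kernel_fused_experts(
#     output, hidden_states, w1, w2, routing_data, gather_idx, scatter_idx,
#     topk=topk, activation=activation, quant_config=quant_config,
#     apply_router_weight_on_input=apply_router_weight_on_input,
# )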
Xin Yang 2025-11-25 18:21:00 -08:00 committed by GitHub
parent c5ee430328
commit 53d7f1f601


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output, topk, sm_first=not renormalize
)
output = torch.empty_like(hidden_states)
return triton_kernel_fused_experts(
None,
output,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
topk=topk,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
topk: int,
activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
@@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
intermediate_cache: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if quant_config is None:
@@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
batch_dim = 1
M, K = hidden_states.shape[-2:]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
if intermediate_cache is None:
intermediate_cache = torch.empty(
(batch_dim, M * topk, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache = _resize_cache(
intermediate_cache, (batch_dim, M * topk, N // 2)
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
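The intermediate buffer is either taken from a caller-provided workspace or allocated once here, and _resize_cache reinterprets it as the 3D shape matmul_ogs expects (hence the extra batch_dim) without allocating new memory. A rough standalone sketch of that reshaping idea (my own helper for illustration, not the vLLM _resize_cache itself):

import math
import torch

def resize_cache(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Reuse the first prod(shape) elements of an existing buffer as a view
    # with the requested shape; no new memory is allocated.
    numel = math.prod(shape)
    assert buf.numel() >= numel, "workspace too small"
    return buf.flatten()[:numel].view(shape)

# Example: a (M * topk, N // 2) workspace viewed as the 3D buffer matmul_ogs wants.
M, topk, N = 128, 4, 1024
workspace = torch.empty(M * topk * (N // 2))
cache3d = resize_cache(workspace, (1, M * topk, N // 2))  # leading batch_dim = 1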
@@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
)
gammas = routing_data.gate_scal if routing_data else None
intermediate_cache1 = matmul_ogs(
matmul_ogs(
hidden_states,
w1,
quant_config.w1_bias,
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
y=intermediate_cache,
)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
matmul_ogs(
intermediate_cache.view(M * topk, N // 2),
w2,
quant_config.w2_bias,
routing_data,
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
return intermediate_cache3
output_tensor = output_tensor.view(M, K)
return output_tensor
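Both matmul_ogs calls now write into pre-allocated destinations through y=: the first (w1 with the fused swiglu activation) fills intermediate_cache, the second (w2) fills output_tensor directly, so nothing is returned and later copied. A toy sketch of the same two-stage, write-into-buffer pattern using plain torch.matmul(out=...) as a stand-in for matmul_ogs (routing and gather/scatter omitted):

import torch

M, K, N = 8, 64, 128                        # toy sizes; real shapes come from the MoE layer
x = torch.randn(M, K)
w1 = torch.randn(K, N)
w2 = torch.randn(N // 2, K)

inter = torch.empty(M, N)                   # plays the role of intermediate_cache
out = torch.empty(M, K)                     # plays the role of output_tensor

torch.matmul(x, w1, out=inter)              # stage 1 writes into the pre-allocated buffer
gate, up = inter.chunk(2, dim=-1)
act = torch.nn.functional.silu(gate) * up   # swiglu-style gating; the real kernel fuses this,
                                            # so its cache only needs N // 2 columns
torch.matmul(act, w2, out=out)              # stage 2 writes straight into the final output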
def make_routing_data(
@@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return True
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a1: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, _, N = w1.size()
K = a1.size(-1)
assert a1.dim() == 2
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
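For concreteness, with w1 laid out as (E, K, N) and w2 as (E, N // 2, K), as the asserts in triton_kernel_fused_experts require, the tuple is read straight off the tensor shapes. A toy example with deliberately small, made-up sizes:

import torch

E, K, N = 4, 64, 256             # experts, hidden size, 2 * intermediate size (toy values)
M, topk = 8, 2                   # tokens, experts selected per token

a1 = torch.empty(M, K)
w1 = torch.empty(E, K, N)        # N is read from the trailing dim of w1, not of w2
w2 = torch.empty(E, N // 2, K)
topk_ids = torch.zeros(M, topk, dtype=torch.int64)

# Mirrors moe_problem_size above: yields (E, M, N, K, topk) == (4, 8, 256, 64, 2)
problem = (w1.size(0), a1.size(0), w1.size(2), a1.size(-1), topk_ids.size(1))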
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP()
@@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspaces are allocated inside the kernel
workspace1 = (M, K)
workspace2 = (0, 0)
workspace1 = (0, 0)
workspace2 = (M * topk, N // 2)
output = (M, K)
return (workspace1, workspace2, output)
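With this change workspace2, sized up front by the modular-kernel framework, backs the intermediate activations and is handed to the kernel as intermediate_cache, while workspace1 is no longer needed; the output buffer keeps its (M, K) shape. Roughly, the shapes map onto the two matmul stages like this (a sketch of the bookkeeping, not the framework's allocator):

def workspace_shapes_sketch(M: int, N: int, K: int, topk: int):
    # Stage 1 (w1 matmul + fused swiglu) produces (M * topk, N // 2) -> reused as intermediate_cache.
    # Stage 2 (w2 matmul) produces (M, K) -> written straight into the final output.
    workspace1 = (0, 0)
    workspace2 = (M * topk, N // 2)
    output = (M, K)
    return workspace1, workspace2, output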
@@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids, topk_weights, local_num_experts
)
experts_output = triton_kernel_fused_experts(
None,
topk = topk_ids.size(1)
triton_kernel_fused_experts(
output,
hidden_states,
w1,
w2,
routing_data,
gather_indx,
scatter_indx,
topk=topk,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
global_num_experts=local_num_experts,
expert_map=None, # applied already
intermediate_cache=workspace2,
a1q_scale=a1q_scale,
)
output.copy_(experts_output, non_blocking=True)
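Because the second matmul_ogs writes straight into the caller's buffer (y=output_tensor), the old trailing output.copy_(experts_output, non_blocking=True) is dropped, saving one result-sized allocation and one device copy per MoE layer invocation. A before/after sketch of the call site (routing arguments elided; names as in the diff):

# Before: the kernel returned a newly allocated tensor that was then copied into `output`.
# experts_output = triton_kernel_fused_experts(None, hidden_states, w1, w2, ...)
# output.copy_(experts_output, non_blocking=True)

# After: the caller's `output` is filled in place; workspace2 backs the intermediate cache.
# triton_kernel_fused_experts(output, hidden_states, w1, w2, ...,
#                             topk=topk, intermediate_cache=workspace2, a1q_scale=a1q_scale)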