Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-24 13:16:01 +08:00
[Kernel] Use pre-allocated output buffer for triton kernel fused_experts (#29219)
Signed-off-by: Xin Yang <xyangx@amazon.com>
parent c5ee430328
commit 53d7f1f601
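
In short: triton_kernel_fused_experts no longer allocates its own result tensor; the caller pre-allocates the output buffer and the kernel writes into it directly, dropping one temporary tensor and one device copy per call. A minimal sketch of this buffer-reuse pattern (function names below are illustrative, not vLLM's API):

import torch

def fused_op_alloc(x: torch.Tensor) -> torch.Tensor:
    # Before: the op allocates its own result, so the caller must copy it
    # into the buffer it already owns.
    return x * 2  # stand-in for the fused computation

def fused_op_prealloc(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # After: the caller passes `out` and the op writes into it in place.
    torch.mul(x, 2, out=out)
    return out

x = torch.randn(4, 8)
out = torch.empty_like(x)
out.copy_(fused_op_alloc(x))   # old call-site shape: compute, then copy
fused_op_prealloc(out, x)      # new call-site shape: write directly
assert torch.equal(out, x * 2)
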
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
         gating_output, topk, sm_first=not renormalize
     )
 
+    output = torch.empty_like(hidden_states)
+
     return triton_kernel_fused_experts(
-        None,
+        output,
         hidden_states,
         w1,
         w2,
         routing_data,
         gather_idx,
         scatter_idx,
+        topk=topk,
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
     routing_data,  # RoutingData
     gather_indx,  # GatherIndx
     scatter_indx,  # ScatterIndx
+    topk: int,
     activation: str = "silu",
     quant_config: FusedMoEQuantConfig | None = None,
     swiglu_alpha: float = 1.702,
@@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    intermediate_cache: torch.Tensor | None = None,
     a1q_scale: torch.Tensor | None = None,
 ) -> torch.Tensor:
     if quant_config is None:
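
The new intermediate_cache argument follows the usual optional-scratch idiom: callers that manage a workspace pool pass a buffer in, while the eager path leaves it as None and lets the function allocate (see the next hunk). A sketch of that idiom, with hypothetical names:

import torch

def fused_step(x: torch.Tensor, scratch: torch.Tensor | None = None) -> torch.Tensor:
    if scratch is None:
        # Eager path: no pooled workspace provided, allocate per call.
        scratch = torch.empty(x.shape[0], 2 * x.shape[1], dtype=x.dtype)
    torch.cat([x, x], dim=1, out=scratch)  # stand-in for the first matmul
    return scratch

x = torch.randn(4, 8)
pooled = torch.empty(4, 16)
assert torch.equal(fused_step(x), fused_step(x, scratch=pooled))
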
@@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
     assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
 
     # Shape check, only check non-mxfp4
     assert hidden_states.ndim == 2
     assert hidden_states.shape[-1] == w1.shape[-2]
     assert w2.shape[-1] == w1.shape[1]
 
+    batch_dim = 1
+    M, K = hidden_states.shape[-2:]
     E, _, N = w1.shape
 
     if global_num_experts == -1:
         global_num_experts = E
+
+    if intermediate_cache is None:
+        intermediate_cache = torch.empty(
+            (batch_dim, M * topk, N // 2),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+    # Add batch_dim to output buffer because matmul_ogs expects 3D output
+    intermediate_cache = _resize_cache(
+        intermediate_cache, (batch_dim, M * topk, N // 2)
+    )
+    output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
+
     act = FusedActivation(
         FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
         (swiglu_alpha, swiglu_limit),
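
_resize_cache reinterprets the provided buffers under the shapes this kernel needs (adding the batch dimension matmul_ogs expects) without allocating. A rough sketch of what a _resize_cache-style helper does, simplified from the helper in vllm/model_executor/layers/fused_moe/utils.py; the real implementation may differ in details:

import math
import torch

def resize_cache(x: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Reinterpret the leading elements of a flat scratch buffer under a new
    # shape: no new allocation, just a view over the same storage.
    numel = math.prod(shape)
    assert x.numel() >= numel, "scratch buffer too small for requested view"
    return x.flatten()[:numel].view(shape)

scratch = torch.empty(2048)
a3d = resize_cache(scratch, (1, 16, 32))  # 3D view, as matmul_ogs expects
a2d = resize_cache(scratch, (16, 32))     # 2D view over the same storage
assert a3d.data_ptr() == scratch.data_ptr() == a2d.data_ptr()
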
@@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
     )
     gammas = routing_data.gate_scal if routing_data else None
 
-    intermediate_cache1 = matmul_ogs(
+    matmul_ogs(
         hidden_states,
         w1,
         quant_config.w1_bias,
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
         precision_config=quant_config.w1_precision,
         gammas=gammas if apply_router_weight_on_input else None,
         fused_activation=act,
+        y=intermediate_cache,
     )
 
-    intermediate_cache3 = matmul_ogs(
-        intermediate_cache1,
+    matmul_ogs(
+        intermediate_cache.view(M * topk, N // 2),
         w2,
         quant_config.w2_bias,
         routing_data,
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
         gammas=None if apply_router_weight_on_input else gammas,
         y=output_tensor,
     )
-    return intermediate_cache3
+    output_tensor = output_tensor.view(M, K)
+    return output_tensor
 
 
 def make_routing_data(
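
Both matmul_ogs calls now land directly in caller-visible storage (y=intermediate_cache, then y=output_tensor), and the function returns a 2D view of the output buffer rather than a freshly allocated result. The same pattern, with plain torch.matmul standing in for the Triton kernel:

import torch

M, K, N = 4, 8, 6
x = torch.randn(M, K)
w = torch.randn(K, N)

y = torch.empty(1, M, N)      # pre-allocated 3D output buffer (batch_dim = 1)
torch.matmul(x, w, out=y[0])  # the kernel writes in place
y2d = y.view(M, N)            # hand back a 2D view of the same memory
assert torch.allclose(y2d, x @ w)
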
@@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def supports_expert_map(self) -> bool:
         return True
 
+    def moe_problem_size(
+        self,
+        a1: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> tuple[int, int, int, int, int]:
+        """
+        Extract the MoE problem size from the given tensor arguments:
+        - a: The hidden states, input to the MoE layer.
+        - w1: The first set of expert weights.
+        - w2: The second set of expert weights.
+        - topk_ids: The topk ids.
+        Note: extracting the problem shape from the weight and activation
+        tensors is not obvious. It needs to be done this way specifically
+        due to subtle issues with particular kernels, e.g. the int4 kernels
+        divide the trailing dimension by two, so it's not "correct" to
+        extract N or K from the trailing dimension of w1 or w2. Similarly,
+        some kernels transpose the weights, so this needs to be kept in mind.
+        Note: This implementation covers most cases. However, if experts
+        require a specialized implementation, like MarlinExperts, they are free
+        to override this function.
+        """
+        assert w1.dim() == 3 and w2.dim() == 3
+        E, _, N = w1.size()
+        K = a1.size(-1)
+
+        assert a1.dim() == 2
+        assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
+        M = a1.size(0)
+
+        assert topk_ids.dim() == 2
+        topk = topk_ids.size(1)
+
+        return E, M, N, K, topk
+
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Weight application and reduction happens in the fused_experts kernel.
         return TopKWeightAndReduceNoOP()
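
moe_problem_size recovers (E, M, N, K, topk) from the tensor shapes alone. A quick self-contained check of that bookkeeping, with shape conventions inferred from the asserts earlier in this file and purely illustrative sizes:

import torch

E, M, N, K, topk = 8, 16, 128, 64, 2
w1 = torch.empty(E, K, N)           # (experts, hidden, intermediate)
w2 = torch.empty(E, N // 2, K)      # swiglu halves the intermediate dim
a1 = torch.empty(M, K)              # hidden states
topk_ids = torch.zeros(M, topk, dtype=torch.long)

e, _, n = w1.size()
k = a1.size(-1)
m = a1.size(0)
t = topk_ids.size(1)
assert (e, m, n, k, t) == (E, M, N, K, topk)
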
@@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         # workspace are allocated inside the kernel
-        workspace1 = (M, K)
-        workspace2 = (0, 0)
+        workspace1 = (0, 0)
+        workspace2 = (M * topk, N // 2)
         output = (M, K)
         return (workspace1, workspace2, output)
 
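
The workspace contract flips: workspace2 now advertises the swiglu intermediate shape (M * topk, N // 2) so the modular-kernel framework pre-allocates it and apply() can forward it as intermediate_cache, while workspace1 is no longer needed. Illustratively (made-up sizes, simplified allocation):

import math
import torch

M, K, N, topk = 16, 64, 128, 2
workspace1 = (0, 0)              # unused by this expert class
workspace2 = (M * topk, N // 2)  # swiglu intermediate, per the hunk above
output = (M, K)

# The framework would allocate buffers from these shapes and hand the second
# one back into apply(), which forwards it as intermediate_cache below.
ws2 = torch.empty(math.prod(workspace2))
out = torch.empty(output)
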
@@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
             topk_ids, topk_weights, local_num_experts
         )
 
-        experts_output = triton_kernel_fused_experts(
-            None,
+        topk = topk_ids.size(1)
+        triton_kernel_fused_experts(
+            output,
             hidden_states,
             w1,
             w2,
             routing_data,
             gather_indx,
             scatter_indx,
+            topk=topk,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
             global_num_experts=local_num_experts,
             expert_map=None,  # applied already
+            intermediate_cache=workspace2,
             a1q_scale=a1q_scale,
         )
-
-        output.copy_(experts_output, non_blocking=True)