[Kernel] DeepGemm MoE : Integrate triton permute / unpermute kernels (#20903)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Authored by Varun Sundar Rabindranath on 2025-07-17 13:40:37 +05:30, committed by GitHub
parent fdc5b43d20
commit 11dfdf21bf
10 changed files with 490 additions and 58 deletions

View File

@@ -85,7 +85,6 @@ def make_config_arg_parser(description: str):
help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
nargs="+",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl."
)
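
With nargs="+" removed above, --fused-moe-chunk-size now parses a single integer rather than a list. A minimal, standalone sketch of the resulting behaviour (the argument name and help text come from the hunk above; the surrounding parser is hypothetical):

import argparse

parser = argparse.ArgumentParser(description="fused-moe benchmark config")
parser.add_argument(
    "--fused-moe-chunk-size",
    type=int,
    help="Fused moe chunk size used for the non-batched fused experts impl.")

args = parser.parse_args(["--fused-moe-chunk-size", "32768"])
assert args.fused_moe_chunk_size == 32768  # a single int, no longer a list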

View File

@@ -239,6 +239,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
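
For reference, the expert_tokens_metadata argument now threaded through every workspace_shapes() implementation in this diff carries per-expert token counts. A rough sketch of its shape, inferred from how it is constructed and consumed elsewhere in this change (the actual definition lives in modular_kernel.py and may differ in detail):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExpertTokensMetadata:
    # Per-local-expert token counts, resident on the accelerator.
    expert_num_tokens: torch.Tensor
    # Optional host-side copy; when present, implementations such as
    # DeepGemmExperts can size their workspaces exactly instead of using
    # a pessimistic upper bound.
    expert_num_tokens_cpu: Optional[torch.Tensor] = None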

View File

@@ -116,6 +116,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -123,11 +124,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if self.allow_deep_gemm:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata)
else:
assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,

View File

@@ -271,6 +271,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()

View File

@@ -8,16 +8,16 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
logger = init_logger(__name__)
@@ -93,18 +93,25 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert self.block_shape is not None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
num_experts = global_num_experts
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = (M_sum, max(N // 2, K))
workspace2 = (M_sum, max(N, K))
M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
expert_tokens_meta)
assert M_sum % block_m == 0
workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
@@ -131,43 +138,40 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
apply_router_weight_on_input: bool,
):
assert self.block_shape is not None
assert a1q_scale is not None
a1q = hidden_states
_, N, K = w1.size()
M, _ = output.size()
num_topk = topk_ids.size(1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = w1.size(0)
global_num_experts = local_num_experts
assert w2.size(1) == K
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
a1q,
a1q_scale,
topk_ids,
global_num_experts,
expert_map,
self.block_shape[0],
)
M_sum = compute_aligned_M(M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=deep_gemm_block_shape()[0],
expert_tokens_meta=expert_tokens_meta)
if expert_map is not None:
# DeepGemm (Grouped Contiguous) kernel needs a valid B index
# for all rows of A. To that effect, simply compute with
# the 0th weight matrix.
# Note that this relies on the fact that corresponding topk
# weights would be 0 during weight multiplication.
expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)
mm1_out = _resize_cache(workspace2, (M_sum, N))
act_out = _resize_cache(workspace13, (M_sum, N // 2))
quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
(M_sum, K))
mm1_out = _resize_cache(workspace13, (M_sum, N))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2))
mm2_out = _resize_cache(workspace13, (M_sum, K))
perm_out = _resize_cache(workspace2, (M * num_topk, K))
mm2_out = _resize_cache(workspace2, (M_sum, K))
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
expert_tokens_meta=expert_tokens_meta,
aq_out=a1q_perm)
assert a1q.size(0) == M_sum
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
mm1_out, expert_ids)
@@ -183,14 +187,15 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
TopKWeightAndReduceContiguous().apply(
output=output,
fused_expert_output=perm_out,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)
deepgemm_unpermute_and_reduce(a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=output)
def deep_gemm_moe_fp8(
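
To make the new workspace carving concrete: workspace13 is reused for mm1_out and quant_out, while workspace2 is reused for a1q_perm, act_out and mm2_out, which is why workspace_shapes() now requests max(N, K) columns for the first buffer and max(N // 2, K) for the second. A small arithmetic sketch of the pessimistic sizing path (no expert_num_tokens_cpu available), assuming the DeepGemm contiguous-layout alignment is 128 and using made-up problem sizes:

M, num_topk, local_num_experts = 4, 2, 8   # hypothetical sizes
N, K, block_m = 2048, 1024, 128

# Worst case: every local expert may waste up to block_m - 1 rows of padding.
M_sum = (M * num_topk) + local_num_experts * (block_m - 1)   # 1024
M_sum = ((M_sum + block_m - 1) // block_m) * block_m         # round_up -> 1024

workspace1 = (M_sum, max(N, K))       # (1024, 2048): mm1_out / quant_out
workspace2 = (M_sum, max(N // 2, K))  # (1024, 1024): a1q_perm / act_out / mm2_out
output = (M, K)                       # (4, 1024)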

View File

@@ -0,0 +1,413 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00ba5c0cef8bff6/lightllm/common/fused_moe/deepep_scatter_gather.py
and updated to fit vllm needs and terminology.
"""
import functools
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
from vllm.triton_utils import tl, triton
from vllm.utils import round_up
@functools.cache
def deep_gemm_block_shape() -> list[int]:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
block = dg.get_m_alignment_for_contiguous_layout()
return [block, block]
def expert_num_tokens_round_up_and_sum(expert_num_tokens: torch.Tensor,
alignment: int) -> int:
# Round up each element in expert_num_tokens to the nearest multiple of
# alignment.
ent = (expert_num_tokens.to(torch.int64) +
(alignment - 1)) // alignment * alignment
return torch.sum(ent).item()
def compute_aligned_M(M: int, num_topk: int, local_num_experts: int,
alignment: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
if ((expert_tokens_meta is not None)
and (expert_tokens_meta.expert_num_tokens_cpu is not None)):
return expert_num_tokens_round_up_and_sum(
expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment)
# expert_num_tokens information is not available on the cpu.
# compute the max required size.
M_sum = (M * num_topk) + local_num_experts * (alignment - 1)
M_sum = round_up(M_sum, alignment)
return M_sum
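# Editorial worked example (hypothetical helper, illustration only) of the two
# sizing paths above.
def _compute_aligned_M_example():
    # With per-expert counts available on the CPU, each count is rounded up to
    # the alignment independently and then summed: 3 -> 128, 130 -> 256,
    # 0 -> 0, for a total of 384.
    cpu_counts = torch.tensor([3, 130, 0])
    assert expert_num_tokens_round_up_and_sum(cpu_counts, alignment=128) == 384
    # Without CPU counts, the pessimistic bound is used instead:
    # (64 * 8) + 3 * 127 = 893 -> round_up(893, 128) = 896.
    assert compute_aligned_M(M=64, num_topk=8, local_num_experts=3,
                             alignment=128, expert_tokens_meta=None) == 896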
@triton.jit
def apply_expert_map(expert_id, expert_map):
if expert_id != -1:
expert_id = tl.load(expert_map + expert_id).to(tl.int64)
return expert_id
@triton.jit
def round_up_128(x: int) -> int:
y = 128
return ((x + y - 1) // y) * y
@triton.jit
def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts: tl.constexpr,
BLOCK_E: tl.constexpr,
BLOCK_EXPERT_NUM: tl.constexpr,
):
cur_expert = tl.program_id(0)
offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
tokens_per_expert = tl.load(num_recv_tokens_per_expert + offset_cumsum,
mask=offset_cumsum < num_experts,
other=0)
tokens_per_expert = round_up_128(tokens_per_expert)
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
tl.store(expert_start_loc + offset_cumsum,
cumsum,
mask=offset_cumsum < num_experts)
cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E)
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
tl.store(
m_indices_start_ptr + start_m + off_expert,
cur_expert,
)
@triton.jit
def _fwd_kernel_ep_scatter_2(
total_token_num,
expert_start_loc,
recv_x,
recv_x_stride0,
recv_x_stride1,
recv_x_scale,
recv_x_scale_stride0,
recv_x_scale_stride1,
recv_topk,
recv_topk_stride0,
recv_topk_stride1,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
output_tensor_scale,
output_tensor_scale_stride0,
output_tensor_scale_stride1,
output_index,
output_index_stride0,
output_index_stride1,
topk_num: tl.constexpr,
expert_map,
HAS_EXPERT_MAP: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
HIDDEN_SIZE_PAD: tl.constexpr,
SCALE_HIDDEN_SIZE: tl.constexpr,
SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
start_token_id = tl.program_id(0)
grid_num = tl.num_programs(0)
offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
for token_id in range(start_token_id, total_token_num, grid_num):
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in,
mask=mask)
to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 +
offset_in_s,
mask=mask_s)
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 +
topk_index)
if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id,
1)
tl.store(
output_index + token_id * output_index_stride0 +
topk_index, dest_token_index)
output_tensor_ptr = (output_tensor +
dest_token_index * output_tensor_stride0)
output_tensor_scale_ptr = (
output_tensor_scale +
dest_token_index * output_tensor_scale_stride0)
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
tl.store(output_tensor_scale_ptr + offset_in_s,
to_copy_s,
mask=mask_s)
@torch.no_grad()
def ep_scatter(
recv_x: torch.Tensor,
recv_x_scale: torch.Tensor,
recv_topk: torch.Tensor,
num_recv_tokens_per_expert: torch.Tensor,
expert_map: Optional[torch.Tensor],
expert_start_loc: torch.Tensor,
output_tensor: torch.Tensor,
output_tensor_scale: torch.Tensor,
m_indices: torch.Tensor,
output_index: torch.Tensor,
):
BLOCK_E = 128  # per-expert token count is aligned to 128
BLOCK_D = 128 # block size of quantization
num_warps = 8
num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0
_fwd_kernel_ep_scatter_1[(grid, )](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid, )](
recv_topk.shape[0],
expert_start_loc,
recv_x,
recv_x.stride(0),
recv_x.stride(1),
recv_x_scale,
recv_x_scale.stride(0),
recv_x_scale.stride(1),
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor_scale,
output_tensor_scale.stride(0),
output_tensor_scale.stride(1),
output_index,
output_index.stride(0),
output_index.stride(1),
topk_num=recv_topk.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
)
return
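# Editorial illustration (hypothetical helper): what _fwd_kernel_ep_scatter_1
# computes, expressed with plain torch ops for readability.
def _scatter_phase1_reference(num_recv_tokens_per_expert: torch.Tensor):
    # Round each expert's token count up to the 128-row alignment.
    aligned = ((num_recv_tokens_per_expert.to(torch.int64) + 127) // 128) * 128
    # Exclusive cumsum of the aligned counts gives each expert's slot start.
    expert_start_loc = torch.cumsum(aligned, dim=0) - aligned
    # Each row of an expert's slot is labelled with that expert's id (the
    # Triton kernel writes this into the preallocated m_indices in place).
    m_indices = torch.repeat_interleave(
        torch.arange(aligned.numel(), device=aligned.device), aligned)
    return expert_start_loc, m_indices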
@triton.jit
def _fwd_kernel_ep_gather(
total_token_num,
input_tensor,
input_tensor_stride0,
input_tensor_stride1,
recv_topk_ids,
recv_topk_ids_stride0,
recv_topk_ids_stride1,
recv_topk_weight,
recv_topk_weight_stride0,
recv_topk_weight_stride1,
input_index,
input_index_stride0,
input_index_stride1,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
topk_num: tl.constexpr,
expert_map,
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_D: tl.constexpr,
):
cur_block = tl.program_id(0)
start_cur_token = tl.program_id(1)
grid_num = tl.num_programs(1)
for cur_token in range(start_cur_token, total_token_num, grid_num):
off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index in range(0, topk_num):
expert_id = tl.load(recv_topk_ids +
cur_token * recv_topk_ids_stride0 + topk_index)
if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0:
source_token_index = tl.load(input_index +
cur_token * input_index_stride0 +
topk_index)
acc_weight = tl.load(recv_topk_weight +
cur_token * recv_topk_weight_stride0 +
topk_index)
tmp = tl.load(input_tensor +
source_token_index * input_tensor_stride0 +
cur_block * BLOCK_D + off_d)
accumulator += tmp.to(tl.float32) * acc_weight
tl.store(
output_tensor + cur_token * output_tensor_stride0 +
cur_block * BLOCK_D + off_d,
accumulator.to(output_tensor.dtype.element_ty),
)
@torch.no_grad()
def ep_gather(
input_tensor: torch.Tensor,
recv_topk_ids: torch.Tensor,
recv_topk_weight: torch.Tensor,
input_index: torch.Tensor,
expert_map: Optional[torch.Tensor],
output_tensor: torch.Tensor,
):
num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
BLOCK_D = min(hidden_size, 1024)
assert hidden_size % BLOCK_D == 0
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
_fwd_kernel_ep_gather[grid](
num_tokens,
input_tensor,
input_tensor.stride(0),
input_tensor.stride(1),
recv_topk_ids,
recv_topk_ids.stride(0),
recv_topk_ids.stride(1),
recv_topk_weight,
recv_topk_weight.stride(0),
recv_topk_weight.stride(1),
input_index,
input_index.stride(0),
input_index.stride(1),
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
topk_num=recv_topk_ids.shape[1],
expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps,
BLOCK_D=BLOCK_D,
)
return
def deepgemm_moe_permute(aq: torch.Tensor,
aq_scale: torch.Tensor,
topk_ids: torch.Tensor,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
aq_out: Optional[torch.Tensor] = None):
assert aq.ndim == 2
assert topk_ids.dtype.is_signed, (
"The kernel uses -1 to represent invalid topk_ids")
H = aq.size(1)
device = aq.device
block_m = deep_gemm_block_shape()[0]
block_k = deep_gemm_block_shape()[1]
M_sum = compute_aligned_M(M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=block_m,
expert_tokens_meta=expert_tokens_meta)
expert_start_loc = torch.empty((local_num_experts),
device=device,
dtype=torch.int32)
assert aq_out is None or aq_out.shape == (M_sum, H)
if aq_out is None:
aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype)
aq_scale_out = torch.empty((M_sum, H // block_k),
device=device,
dtype=torch.float32)
maybe_has_empty_blocks = ((expert_tokens_meta is None)
or (expert_tokens_meta.expert_num_tokens_cpu
is None))
expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty
expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
else:
expert_num_tokens = count_expert_num_tokens(topk_ids,
local_num_experts,
expert_map)
ep_scatter(recv_x=aq,
recv_x_scale=aq_scale,
recv_topk=topk_ids,
num_recv_tokens_per_expert=expert_num_tokens,
expert_start_loc=expert_start_loc,
expert_map=expert_map,
output_tensor=aq_out,
output_tensor_scale=aq_scale_out,
m_indices=expert_ids,
output_index=inv_perm)
return aq_out, aq_scale_out, expert_ids, inv_perm
def deepgemm_unpermute_and_reduce(
a: torch.Tensor, # Grouped gemm output
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
inv_perm: torch.Tensor,
expert_map: Optional[torch.Tensor],
output: torch.Tensor):
return ep_gather(input_tensor=a,
recv_topk_ids=topk_ids,
recv_topk_weight=topk_weights,
input_index=inv_perm,
expert_map=expert_map,
output_tensor=output)
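
For intuition (and as a possible unit-test oracle), the following is a slow, pure-PyTorch reference of what deepgemm_unpermute_and_reduce / ep_gather compute. It is an illustrative sketch, not part of this commit, and relies on the torch import at the top of this module:

def unpermute_and_reduce_reference(a, topk_ids, topk_weights, inv_perm,
                                   expert_map, output):
    # a:        (M_sum, K) grouped-gemm output rows, grouped by local expert
    # topk_ids: (M, topk) global expert ids, with -1 marking invalid entries
    # inv_perm: (M, topk) row index into `a` for each (token, topk-slot)
    # output:   (M, K) reduced result, written in place
    M, topk = topk_ids.shape
    for t in range(M):
        acc = torch.zeros(a.size(1), dtype=torch.float32, device=a.device)
        for k in range(topk):
            eid = int(topk_ids[t, k])
            if expert_map is not None and eid != -1:
                eid = int(expert_map[eid])
            if eid >= 0:
                acc += (a[inv_perm[t, k]].to(torch.float32)
                        * float(topk_weights[t, k]))
        output[t] = acc.to(output.dtype)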

View File

@@ -677,6 +677,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.num_dispatchers
@@ -889,6 +890,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.num_dispatchers

View File

@@ -1618,6 +1618,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))

View File

@@ -317,6 +317,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
@@ -479,7 +480,8 @@ class FusedMoEModularKernel(torch.nn.Module):
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
@@ -572,10 +574,9 @@ class FusedMoEModularKernel(torch.nn.Module):
assert num_chunks > 1
# Construct the entire output that can then be processed in chunks.
(_, _, fused_out_shape,
_) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts,
local_num_experts)
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
@@ -613,8 +614,11 @@ class FusedMoEModularKernel(torch.nn.Module):
need_expert_num_tokens_cpu = (
full_expert_tokens_meta.expert_num_tokens_cpu is not None)
if need_expert_num_tokens_cpu:
# This is blocking as some implementations need the count
# on the CPU to determine appropriate input/output fused-moe
# buffers.
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
"cpu", non_blocking=True)
"cpu", non_blocking=False)
return ExpertTokensMetadata(
expert_num_tokens=c_expert_num_tokens,
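
The switch from non_blocking=True to non_blocking=False above matters because the counts are consumed on the host right away (for example by compute_aligned_M when sizing the DeepGemm workspaces); with an asynchronous device-to-host copy the values are not guaranteed to be ready when read. A minimal sketch of the safe pattern, with hypothetical tensors and assuming a CUDA device:

import torch

counts_gpu = torch.tensor([3, 130, 0], device="cuda", dtype=torch.int32)

# Blocking copy: the host tensor is guaranteed to be populated before use.
counts_cpu = counts_gpu.to("cpu", non_blocking=False)
total_rows = int(counts_cpu.sum())

# With non_blocking=True (notably when the destination is pinned memory) the
# copy may still be in flight when the values are read, so an explicit
# synchronization would be required instead:
#   counts_cpu = counts_gpu.to("cpu", non_blocking=True)
#   torch.cuda.current_stream().synchronize()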

View File

@@ -102,6 +102,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -110,11 +111,13 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
or is_blackwell_deep_gemm_used()):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_meta)
else:
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
global_num_experts,
local_num_experts)
local_num_experts,
expert_tokens_meta)
def apply(
self,