[BugFix] : Fix Batched DeepGemm Experts (#19515)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath, 2025-06-12 22:43:02 -04:00 (committed by GitHub)
parent e6aab5de29
commit e3b12667d4
9 changed files with 52 additions and 32 deletions


@@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        # FIXME (varun): We should be able to dispatch only from the leader
+        # DP ranks in the case of TP > 1. At the moment, all the Ranks
+        # end up sending their tokens. This needs to be fixed.
+        num_dispatchers = self.world_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
-        workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
-        output = (num_experts, max_num_tokens * num_dp, K)
+        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
+                       max(K, N))
+        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
+        output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)

     def apply(

@@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a1q = hidden_states
         _, N, K = w1.size()

-        if global_num_experts == -1:
-            global_num_experts = w1.size(0)
-
         assert w2.size(1) == K

         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
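For reference, a minimal standalone sketch of the workspace sizing introduced above, written as a hypothetical free function with toy shapes; only the arithmetic mirrors the diff, the names and values are assumptions:

from typing import Optional, Tuple
import torch

def batched_deepgemm_workspace_shapes(
        a: torch.Tensor, N: int, K: int, local_num_experts: int,
        world_size: int, max_num_tokens: Optional[int] = None
) -> Tuple[tuple, tuple, tuple, torch.dtype]:
    # Hypothetical free-function version of the sizing logic above.
    assert a.dim() == 2
    # Every rank currently dispatches its tokens (see the FIXME), so the
    # per-expert token capacity scales with world_size, not world_size // dp.
    num_dispatchers = world_size
    m = a.size(0) if max_num_tokens is None else max_num_tokens
    workspace13 = (local_num_experts, m * num_dispatchers, max(K, N))
    workspace2 = (local_num_experts, m * num_dispatchers, N // 2)
    output = (local_num_experts, m * num_dispatchers, K)
    return workspace13, workspace2, output, a.dtype

# Toy numbers: 8 local experts, K=512, N=1024, 64 tokens per rank, 4 ranks.
a = torch.empty(64, 512, dtype=torch.bfloat16)
print(batched_deepgemm_workspace_shapes(a, N=1024, K=512,
                                        local_num_experts=8, world_size=4))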


@@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)

     def apply(
         self,
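To make the "pessimistic allocation" note concrete, a toy sketch (the shapes are made up, not taken from vLLM; the only mirrored claim is that the DeepGemm workspace is elementwise at least as large as the Triton one):

import torch

deep_gemm_ws = (8, 64, 512)   # hypothetical, larger
triton_ws = (8, 64, 256)      # hypothetical, smaller

buf = torch.empty(deep_gemm_ws, dtype=torch.bfloat16)
# A buffer allocated for the larger shape can always be re-viewed as the
# smaller one, so allocating pessimistically for DeepGemm stays safe when the
# Triton path is taken at runtime.
triton_view = buf.flatten()[:torch.Size(triton_ws).numel()].view(triton_ws)
print(buf.shape, triton_view.shape)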


@@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()


@@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True

     def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        num_experts = global_num_experts
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
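As a quick worked check of the M_sum sizing above (the numeric values are illustrative): each global expert may need up to block_m - 1 padding slots, and the total is then rounded up to a block multiple:

def deepgemm_m_sum(M: int, topk: int, num_experts: int, block_m: int) -> int:
    # Worst case: every (global) expert's token count is padded to a full block.
    total = (M * topk) + num_experts * (block_m - 1)
    return ((total + block_m - 1) // block_m) * block_m   # round_up

# e.g. M=64 tokens, topk=8, 256 global experts, block_m=128:
print(deepgemm_m_sum(64, 8, 256, 128))  # -> 33024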


@@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        num_dp = self.dp_size
+        num_experts = local_num_experts
         workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
         workspace2 = (self.max_num_tokens * num_dp, N)
         return (workspace13, workspace2, workspace13, a.dtype)

@@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))


@@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M, topk, max(N * 2, K))
         workspace2 = (M, topk, N)


@@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms

@@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)

+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts

         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(

@@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module):
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                 a1, a1q, M, N, K, top_k, global_num_experts)
+                 a1, a1q, M, N, K, top_k, global_num_experts,
+                 local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
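Sketching the new call pattern in isolation (the workspace_shapes stub, the CHUNK_SIZE value, and the toy shapes below are assumptions, not the vLLM objects): local_num_experts comes from the weight tensor, global_num_experts falls back to it when unset, and with chunking the final output is sized with the full M while the workspaces use CHUNK_SIZE:

from typing import Tuple

def workspace_shapes_stub(M: int, N: int, K: int, topk: int,
                          global_num_experts: int,
                          local_num_experts: int) -> Tuple[tuple, tuple, tuple]:
    # Stand-in with the same argument order as the updated interface.
    return ((local_num_experts, M * topk, max(K, N)),
            (M * topk, N),
            (M, K))

CHUNK_SIZE = 32                       # assumed value
M, N, K, topk = 128, 256, 512, 8      # toy problem sizes
w1_shape = (8, 2 * N, K)              # 8 *local* experts

local_num_experts = w1_shape[0]
global_num_experts = -1
if global_num_experts == -1:
    global_num_experts = local_num_experts

num_chunks = (M + CHUNK_SIZE - 1) // CHUNK_SIZE
if num_chunks == 1:
    ws13, ws2, out = workspace_shapes_stub(M, N, K, topk,
                                           global_num_experts,
                                           local_num_experts)
else:
    # Full M only determines the final fused output shape ...
    _, _, out = workspace_shapes_stub(M, N, K, topk, global_num_experts,
                                      local_num_experts)
    # ... while per-chunk workspaces are sized with CHUNK_SIZE.
    ws13, ws2, _ = workspace_shapes_stub(CHUNK_SIZE, N, K, topk,
                                         global_num_experts,
                                         local_num_experts)
print(ws13, ws2, out)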


@@ -159,6 +159,12 @@ def moe_align_block_size(
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.

+    Note: In the case of expert_parallel, moe_align_block_size initially
+    considers all experts as valid and aligns all tokens appropriately.
+    Before the function returns it marks the experts_ids that are not in
+    the current GPU rank as -1 so the MoE matmuls could skip those blocks.
+    This requires the num_experts input arg to be the num global experts.
+
     Parameters:
     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
       top-k expert indices for each token.
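A toy illustration of the behaviour the new docstring note describes (plain PyTorch, not the actual kernel; the expert_map layout and values are assumed): block expert ids are first assigned over all global experts, then blocks whose expert is not owned by this rank are marked -1 so the MoE matmul can skip them:

import torch

# Hypothetical per-block expert assignment over 8 *global* experts.
expert_ids = torch.tensor([0, 1, 3, 3, 5, 6, 7, 2])
# expert_map[g] != -1 iff global expert g lives on this rank (here: experts 4..7).
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])

# Mark blocks belonging to other ranks with -1 so the MoE matmul skips them.
expert_ids = torch.where(expert_map[expert_ids] == -1,
                         torch.full_like(expert_ids, -1),
                         expert_ids)
print(expert_ids)  # tensor([-1, -1, -1, -1,  5,  6,  7, -1])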


@@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm

@@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
-                                                       num_experts)
+                                                       global_num_experts,
+                                                       local_num_experts)

     def apply(
         self,