mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 12:05:38 +08:00
[BugFix] : Fix Batched DeepGemm Experts (#19515)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
parent
e6aab5de29
commit
e3b12667d4
@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
assert a.dim() == 2
|
assert a.dim() == 2
|
||||||
num_dp = self.world_size // self.dp_size
|
# FIXME (varun): We should be able to dispatch only from the leader
|
||||||
|
# DP ranks in the case of TP > 1. At the moment, all the ranks
|
||||||
|
# end up sending their tokens. This needs to be fixed.
|
||||||
|
num_dispatchers = self.world_size
|
||||||
|
num_experts = local_num_experts
|
||||||
max_num_tokens = a.size(
|
max_num_tokens = a.size(
|
||||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
|
||||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
max(K, N))
|
||||||
output = (num_experts, max_num_tokens * num_dp, K)
|
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
|
||||||
|
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
||||||
return (workspace13, workspace2, output, a.dtype)
|
return (workspace13, workspace2, output, a.dtype)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
a1q = hidden_states
|
a1q = hidden_states
|
||||||
_, N, K = w1.size()
|
_, N, K = w1.size()
|
||||||
|
|
||||||
if global_num_experts == -1:
|
|
||||||
global_num_experts = w1.size(0)
|
|
||||||
|
|
||||||
assert w2.size(1) == K
|
assert w2.size(1) == K
|
||||||
|
|
||||||
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
||||||
|
|||||||
@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||||
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
|
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
|
||||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||||
a, aq, M, N, K, topk, num_experts)
|
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||||
else:
|
else:
|
||||||
assert self.batched_triton_experts is not None
|
assert self.batched_triton_experts is not None
|
||||||
return self.batched_triton_experts.workspace_shapes(
|
return self.batched_triton_experts.workspace_shapes(
|
||||||
a, aq, M, N, K, topk, num_experts)
|
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
workspace1: tuple[int, ...] = ()
|
workspace1: tuple[int, ...] = ()
|
||||||
workspace2: tuple[int, ...] = ()
|
workspace2: tuple[int, ...] = ()
|
||||||
|
|||||||
@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||||
a: torch.Tensor,
|
topk: int, global_num_experts: int, local_num_experts: int
|
||||||
aq: torch.Tensor,
|
|
||||||
M: int,
|
|
||||||
N: int,
|
|
||||||
K: int,
|
|
||||||
topk: int,
|
|
||||||
num_experts: int,
|
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
|
# We use global_num_experts due to how moe_align_block_size handles
|
||||||
|
# expert_maps.
|
||||||
|
num_experts = global_num_experts
|
||||||
block_m = self.block_shape[0]
|
block_m = self.block_shape[0]
|
||||||
M_sum = (M * topk) + num_experts * (block_m - 1)
|
M_sum = (M * topk) + num_experts * (block_m - 1)
|
||||||
M_sum = round_up(M_sum, block_m)
|
M_sum = round_up(M_sum, block_m)
|
||||||
|
|||||||
@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
assert a.dim() == 2
|
assert a.dim() == 2
|
||||||
num_dp = self.world_size // self.dp_size
|
num_dp = self.dp_size
|
||||||
|
num_experts = local_num_experts
|
||||||
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
|
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
|
||||||
workspace2 = (self.max_num_tokens * num_dp, N)
|
workspace2 = (self.max_num_tokens * num_dp, N)
|
||||||
return (workspace13, workspace2, workspace13, a.dtype)
|
return (workspace13, workspace2, workspace13, a.dtype)
|
||||||
@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
assert a.dim() == 2
|
assert a.dim() == 2
|
||||||
num_dp = self.world_size // self.dp_size
|
num_dp = self.world_size // self.dp_size
|
||||||
|
num_experts = local_num_experts
|
||||||
max_num_tokens = a.size(
|
max_num_tokens = a.size(
|
||||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||||
|
|||||||
@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
workspace1 = (M, topk, max(N * 2, K))
|
workspace1 = (M, topk, max(N * 2, K))
|
||||||
workspace2 = (M, topk, N)
|
workspace2 = (M, topk, N)
|
||||||
|
|||||||
@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
"""
|
"""
|
||||||
Compute the shapes for the temporary and final outputs of the two gemms
|
Compute the shapes for the temporary and final outputs of the two gemms
|
||||||
@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
a1 = hidden_states
|
a1 = hidden_states
|
||||||
output = a1 if inplace else torch.zeros_like(a1)
|
output = a1 if inplace else torch.zeros_like(a1)
|
||||||
|
|
||||||
|
local_num_experts = w1.size(0)
|
||||||
if global_num_experts == -1:
|
if global_num_experts == -1:
|
||||||
global_num_experts = w1.size(0)
|
global_num_experts = local_num_experts
|
||||||
|
|
||||||
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
||||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||||
@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
if num_chunks == 1:
|
if num_chunks == 1:
|
||||||
(workspace13_shape, workspace2_shape, fused_out_shape,
|
(workspace13_shape, workspace2_shape, fused_out_shape,
|
||||||
workspace_dtype) = self.fused_experts.workspace_shapes(
|
workspace_dtype) = self.fused_experts.workspace_shapes(
|
||||||
a1, a1q, M, N, K, top_k, global_num_experts)
|
a1, a1q, M, N, K, top_k, global_num_experts,
|
||||||
|
local_num_experts)
|
||||||
else:
|
else:
|
||||||
# Use the full M to get the final output shape.
|
# Use the full M to get the final output shape.
|
||||||
_, _, fused_out_shape, _ = (
|
_, _, fused_out_shape, _ = (
|
||||||
self.fused_experts.workspace_shapes(
|
self.fused_experts.workspace_shapes(
|
||||||
a1, a1q, M, N, K, top_k, global_num_experts))
|
a1, a1q, M, N, K, top_k, global_num_experts,
|
||||||
|
local_num_experts))
|
||||||
# Use the CHUNK_SIZE to get the workspace shapes.
|
# Use the CHUNK_SIZE to get the workspace shapes.
|
||||||
workspace13_shape, workspace2_shape, _, workspace_dtype = (
|
workspace13_shape, workspace2_shape, _, workspace_dtype = (
|
||||||
self.fused_experts.workspace_shapes(
|
self.fused_experts.workspace_shapes(
|
||||||
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
|
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
|
||||||
|
local_num_experts))
|
||||||
|
|
||||||
# We can reuse the memory between cache1 and cache3 because by the
|
# We can reuse the memory between cache1 and cache3 because by the
|
||||||
# time we need cache3, we're done with cache1.
|
# time we need cache3, we're done with cache1.
|
||||||
|
|||||||
@ -159,6 +159,12 @@ def moe_align_block_size(
|
|||||||
Aligns the token distribution across experts to be compatible with block
|
Aligns the token distribution across experts to be compatible with block
|
||||||
size for matrix multiplication.
|
size for matrix multiplication.
|
||||||
|
|
||||||
|
Note: In the case of expert_parallel, moe_align_block_size initially
|
||||||
|
considers all experts as valid and aligns all tokens appropriately.
|
||||||
|
Before the function returns, it marks the experts_ids that are not in
|
||||||
|
the current GPU rank as -1 so the MoE matmuls can skip those blocks.
|
||||||
|
This requires the num_experts input arg to be the num global experts.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
||||||
top-k expert indices for each token.
|
top-k expert indices for each token.
|
||||||
|
|||||||
@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
N: int,
|
N: int,
|
||||||
K: int,
|
K: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
num_experts: int,
|
global_num_experts: int,
|
||||||
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||||
@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
|
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
|
||||||
assert self.deep_gemm_expert is not None
|
assert self.deep_gemm_expert is not None
|
||||||
return self.deep_gemm_expert.workspace_shapes(
|
return self.deep_gemm_expert.workspace_shapes(
|
||||||
a, aq, M, N, K, topk, num_experts)
|
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||||
else:
|
else:
|
||||||
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
|
return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
|
||||||
num_experts)
|
global_num_experts,
|
||||||
|
local_num_experts)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user