diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 30b74165657eb..5492399efdf86 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -47,15 +47,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        # FIXME (varun): We should be able to dispatch only from the leader
+        # DP ranks in the case of TP > 1. At the moment, all the ranks
+        # end up sending their tokens. This needs to be fixed.
+        num_dispatchers = self.world_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
-        workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
-        output = (num_experts, max_num_tokens * num_dp, K)
+        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
+                       max(K, N))
+        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
+        output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)
 
     def apply(
@@ -84,9 +90,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a1q = hidden_states
         _, N, K = w1.size()
 
-        if global_num_experts == -1:
-            global_num_experts = w1.size(0)
-
         assert w2.size(1) == K
 
         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
index d0ce59ba1e62f..822cda8205bfe 100644
--- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
@@ -81,18 +81,19 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
 
     def apply(
         self,
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index f380cb77c7e83..3f9ceac8b6e36 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -230,7 +230,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1: tuple[int, ...] = ()
         workspace2: tuple[int, ...] = ()
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 595e8c99514d7..b4473b907381a 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -74,15 +74,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True
 
     def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        aq: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
+        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+        topk: int, global_num_experts: int, local_num_experts: int
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        num_experts = global_num_experts
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index fb66e96c7946e..3bbae4e57ba34 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -521,10 +521,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        num_dp = self.world_size // self.dp_size
+        num_dp = self.dp_size
+        num_experts = local_num_experts
         workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
         workspace2 = (self.max_num_tokens * num_dp, N)
         return (workspace13, workspace2, workspace13, a.dtype)
@@ -624,10 +626,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
+        num_experts = local_num_experts
         max_num_tokens = a.size(
             0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d9b1ba132671a..437e80696ac65 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1553,7 +1553,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M, topk, max(N * 2, K))
         workspace2 = (M, topk, N)
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 9ef6a126680cf..9409b59982d90 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -194,7 +194,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -372,8 +373,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)
 
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts
 
         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
@@ -408,16 +410,19 @@ class FusedMoEModularKernel(torch.nn.Module):
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                 a1, a1q, M, N, K, top_k, global_num_experts)
+                 a1, a1q, M, N, K, top_k, global_num_experts,
+                 local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
index 98e175b12ed45..9d990959e01fa 100644
--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -159,6 +159,12 @@ def moe_align_block_size(
     Aligns the token distribution across experts to be compatible with
     block size for matrix multiplication.
 
+    Note: In the case of expert parallelism, moe_align_block_size initially
+    considers all experts as valid and aligns all tokens appropriately.
+    Before the function returns, it marks the expert_ids that are not on
+    the current GPU rank as -1 so the MoE matmuls can skip those blocks.
+    This requires the num_experts input arg to be the global expert count.
+
     Parameters:
     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
         top-k expert indices for each token.
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index d4233c23f5312..4bbfea446e291 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -48,7 +48,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -56,10 +57,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, aq, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, global_num_experts, local_num_experts)
         else:
             return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
-                                                       num_experts)
+                                                       global_num_experts,
+                                                       local_num_experts)
 
     def apply(
         self,
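For reviewers: a minimal standalone sketch (not part of the diff, and not a vLLM API) of why `workspace_shapes` now takes both expert counts. With expert parallelism, each rank holds only `w1.size(0) == local_num_experts` weight slices, so per-expert workspace buffers are sized with the local count, while kernels indexing the full routing space (e.g. `moe_align_block_size`) need the global count. The function name and all concrete sizes below are illustrative.

```python
def workspace_shapes_sketch(
    max_num_tokens: int,
    num_dispatchers: int,
    N: int,
    K: int,
    global_num_experts: int,
    local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # Batched layouts allocate one slot of max_num_tokens * num_dispatchers
    # tokens per *local* expert: only locally held experts receive tokens,
    # so sizing with global_num_experts would overallocate by the EP factor.
    workspace13 = (local_num_experts, max_num_tokens * num_dispatchers,
                   max(K, N))
    workspace2 = (local_num_experts, max_num_tokens * num_dispatchers, N // 2)
    return workspace13, workspace2


# e.g. 64 global experts spread over 8 EP ranks -> 8 local experts per rank.
w13, w2 = workspace_shapes_sketch(max_num_tokens=256, num_dispatchers=8,
                                  N=4096, K=2048,
                                  global_num_experts=64, local_num_experts=8)
print(w13, w2)  # (8, 2048, 4096) (8, 2048, 2048)
```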
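And a hedged illustration of the behavior the new `moe_align_block_size` docstring describes: blocks are aligned against the *global* expert space first, then expert ids the current rank does not own are marked -1 so the MoE matmuls skip those blocks. The helper below is assumed for illustration; only the `expert_map` convention (local index per global expert, -1 for remote) mirrors the source.

```python
import torch


def mask_remote_experts(expert_ids: torch.Tensor,
                        expert_map: torch.Tensor) -> torch.Tensor:
    # expert_map[g] is the local index of global expert g, or -1 if that
    # expert lives on another rank; a single gather therefore both remaps
    # owned experts to local ids and marks remote ones as -1.
    return expert_map[expert_ids]


# Global experts 0..7; this rank owns experts 4..7 as local experts 0..3.
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
expert_ids = torch.tensor([0, 4, 7, 2])  # one entry per aligned block
print(mask_remote_experts(expert_ids, expert_map))  # tensor([-1,  0,  3, -1])
```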