diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 70a580b9c4c70..0b39432921522 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -260,6 +260,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -273,6 +274,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
index 41faced58f1a5..12df9bb34d258 100644
--- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
@@ -129,30 +129,22 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return self.batched_triton_experts.workspace_shapes(
             a, aq, M, N, K, topk, global_num_experts, local_num_experts)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
-        experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
-                      global_num_experts, expert_map, w1_scale, w2_scale,
-                      w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
-                      workspace2, expert_tokens_meta)
+        experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
+                      activation, global_num_experts, expert_map, w1_scale,
+                      w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
+                      workspace2, expert_tokens_meta,
+                      apply_router_weight_on_input)
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index d6a30e3426950..e479f1b404449 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -291,26 +291,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index b1107a1f47931..cc5e7cf57147a 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
@@ -90,8 +90,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@@ -104,9 +103,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N * 2, K))
+        workspace1 = (M_sum, max(N // 2, K))
         workspace2 = (M_sum, max(N, K))
-        output = (M, topk, K)
+        output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
 
     def apply(
@@ -115,6 +114,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -128,11 +128,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None
 
         a1q = hidden_states
         _, N, K = w1.size()
+        M, _ = output.size()
+        num_topk = topk_ids.size(1)
 
         if global_num_experts == -1:
             global_num_experts = w1.size(0)
@@ -159,11 +162,12 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)
 
-        mm1_out = _resize_cache(workspace13, (M_sum, N))
-        act_out = _resize_cache(workspace2, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+        mm1_out = _resize_cache(workspace2, (M_sum, N))
+        act_out = _resize_cache(workspace13, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        mm2_out = _resize_cache(workspace13, (M_sum, K))
+        perm_out = _resize_cache(workspace2, (M * num_topk, K))
 
         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
@@ -179,7 +183,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)
 
-        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
+        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+
+        TopKWeightAndReduceContiguous().apply(
+            output=output,
+            fused_expert_output=perm_out,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)
 
 
 def deep_gemm_moe_fp8(
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index 61247e93091f1..b311ef1ac1cb2 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -696,15 +696,16 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return t.to(f32) * group_broadcast(scale, t.shape)
 
     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
-              activation: str, global_num_experts: int,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
               expert_map: Optional[torch.Tensor],
               w1_scale: Optional[torch.Tensor],
               w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
               w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         assert hidden_states.dim() == 3
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -899,15 +900,16 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return (workspace13, workspace2, output, a.dtype)
 
     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
-              activation: str, global_num_experts: int,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
               expert_map: Optional[torch.Tensor],
               w1_scale: Optional[torch.Tensor],
               w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
               w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 6a9767fc6f3fd..f0bffc7dae276 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1606,8 +1606,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()
 
     def workspace_shapes(
         self,
@@ -1620,9 +1619,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         global_num_experts: int,
         local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
-        workspace1 = (M, topk, max(N * 2, K))
-        workspace2 = (M, topk, N)
-        output = (M, topk, K)
+        workspace1 = (M, topk, max(N // 2, K))
+        workspace2 = (M, topk, max(N, K))
+        output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
 
     def apply(
@@ -1631,6 +1630,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -1644,6 +1644,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         # Check constraints.
         if self.use_int4_w4a16:
@@ -1696,37 +1697,39 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")
 
-        # We can reuse the memory between these because by the time we need
-        # cache3, we're done with cache1
-        intermediate_cache1 = _resize_cache(workspace13,
+        # Note that the output tensor might be in workspace1
+        intermediate_cache1 = _resize_cache(workspace2,
                                             (num_tokens, top_k_num, N))
-        intermediate_cache2 = _resize_cache(workspace2,
+        intermediate_cache2 = _resize_cache(workspace13,
                                             (num_tokens * top_k_num, N // 2))
+        intermediate_cache3 = _resize_cache(workspace2,
+                                            (num_tokens, top_k_num, K))
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = (
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                  global_num_experts, expert_map))
 
-        invoke_fused_moe_kernel(hidden_states,
-                                w1,
-                                intermediate_cache1,
-                                a1q_scale,
-                                w1_scale,
-                                w1_zp,
-                                None,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                False,
-                                top_k_num,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_act_token_quant,
-                                block_shape=self.block_shape)
+        invoke_fused_moe_kernel(
+            hidden_states,
+            w1,
+            intermediate_cache1,
+            a1q_scale,
+            w1_scale,
+            w1_zp,
+            None,  # topk_weights
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,  # mul_routed_weights
+            top_k_num,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=self.use_fp8_w8a8,
+            use_int8_w8a8=self.use_int8_w8a8,
+            use_int8_w8a16=self.use_int8_w8a16,
+            use_int4_w4a16=self.use_int4_w4a16,
+            per_channel_quant=self.per_act_token_quant,
+            block_shape=self.block_shape)
 
         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
@@ -1739,15 +1742,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
-                                output,
+                                intermediate_cache3,
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
-                                None,
+                                topk_weights,
                                 sorted_token_ids,
                                 expert_ids,
                                 num_tokens_post_padded,
-                                False,
+                                not apply_router_weight_on_input,
                                 1,
                                 config,
                                 compute_type=compute_type,
@@ -1758,6 +1761,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                 per_channel_quant=self.per_act_token_quant,
                                 block_shape=self.block_shape)
 
+        ops.moe_sum(intermediate_cache3, output)
+
 
 def modular_triton_fused_moe(
         use_fp8_w8a8: bool,
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index d0d8c7d6f41e9..028eee2417864 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -360,6 +360,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -373,6 +374,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -384,6 +386,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
           layer.
         - w1 (torch.Tensor): The first set of expert weights.
         - w2 (torch.Tensor): The second set of expert weights.
+        - topk_weights: A map of row to expert weights. Some implementations
+          choose to do weight application.
         - topk_ids (torch.Tensor): A map of row to expert id.
         - activation (str): The activation function to apply after the first
           MoE layer.
@@ -409,6 +413,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
           ExpertTokensMetadata object containing gpu/cpu tensors as big as the
           number of local experts with the information about the number of
           tokens assigned to each local expert.
+        - apply_router_weight_on_input: True if router weights are already
+          applied on the input. This is relevant if the implementation
+          chooses to do weight application.
         """
         raise NotImplementedError
@@ -452,17 +459,21 @@ class FusedMoEModularKernel(torch.nn.Module):
                 f"{fused_experts.__class__.__name__}."
                 f"{fused_experts.activation_formats[0]}")
 
-    def _do_fused_experts(
-            self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
-            a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
-            topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-            local_num_experts: int, expert_map: Optional[torch.Tensor],
-            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
-            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
-            a1q_scale: Optional[torch.Tensor],
-            a2_scale: Optional[torch.Tensor],
-            expert_tokens_meta: Optional[ExpertTokensMetadata]
-    ) -> torch.Tensor:
+    def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
+                          a1: torch.Tensor, a1q: torch.Tensor,
+                          w1: torch.Tensor, w2: torch.Tensor,
+                          topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                          activation: str, global_num_experts: int,
+                          local_num_experts: int,
+                          expert_map: Optional[torch.Tensor],
+                          w1_scale: Optional[torch.Tensor],
+                          w2_scale: Optional[torch.Tensor],
+                          w1_zp: Optional[torch.Tensor],
+                          w2_zp: Optional[torch.Tensor],
+                          a1q_scale: Optional[torch.Tensor],
+                          a2_scale: Optional[torch.Tensor],
+                          expert_tokens_meta: Optional[ExpertTokensMetadata],
+                          apply_router_weight_on_input: bool) -> torch.Tensor:
 
         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -485,36 +496,49 @@ class FusedMoEModularKernel(torch.nn.Module):
             # reuse workspace13 for the output
             fused_out = _resize_cache(workspace13, fused_out_shape)
 
-        self.fused_experts.apply(fused_out,
-                                 a1q,
-                                 w1,
-                                 w2,
-                                 topk_ids=topk_ids,
-                                 activation=activation,
-                                 global_num_experts=global_num_experts,
-                                 expert_map=expert_map,
-                                 w1_scale=w1_scale,
-                                 w2_scale=w2_scale,
-                                 w1_zp=w1_zp,
-                                 w2_zp=w2_zp,
-                                 a1q_scale=a1q_scale,
-                                 a2_scale=a2_scale,
-                                 workspace13=workspace13,
-                                 workspace2=workspace2,
-                                 expert_tokens_meta=expert_tokens_meta)
+        self.fused_experts.apply(
+            fused_out,
+            a1q,
+            w1,
+            w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_zp=w1_zp,
+            w2_zp=w2_zp,
+            a1q_scale=a1q_scale,
+            a2_scale=a2_scale,
+            workspace13=workspace13,
+            workspace2=workspace2,
+            expert_tokens_meta=expert_tokens_meta,
+            apply_router_weight_on_input=apply_router_weight_on_input)
 
         return fused_out
 
     def _maybe_chunk_fused_experts(
-        self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
-        w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
-        global_num_experts: int, local_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        expert_tokens_meta: Optional[ExpertTokensMetadata]
+        self,
+        a1: torch.Tensor,
+        a1q: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
 
         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -529,6 +553,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a1q=a1q,
                 w1=w1,
                 w2=w2,
+                topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 activation=activation,
                 global_num_experts=global_num_experts,
@@ -540,7 +565,8 @@ class FusedMoEModularKernel(torch.nn.Module):
                 w2_zp=w2_zp,
                 a1q_scale=a1q_scale,
                 a2_scale=a2_scale,
-                expert_tokens_meta=expert_tokens_meta)
+                expert_tokens_meta=expert_tokens_meta,
+                apply_router_weight_on_input=apply_router_weight_on_input)
 
         # Chunking required case
         assert num_chunks > 1
@@ -557,11 +583,12 @@ class FusedMoEModularKernel(torch.nn.Module):
         def slice_input_tensors(
             chunk_idx: int
         ) -> tuple[torch.Tensor, Optional[torch.Tensor],
-                   Optional[torch.Tensor], torch.Tensor]:
+                   Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
             s = chunk_idx * CHUNK_SIZE
             e = min(s + CHUNK_SIZE, M)
             return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
-                    _chunk_scales(a2_scale, s, e), topk_ids[s:e])
+                    _chunk_scales(a2_scale, s,
+                                  e), topk_ids[s:e], topk_weights[s:e])
 
         def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
             assert fused_out.size(0) % M == 0, (
@@ -594,7 +621,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 expert_num_tokens_cpu=c_expert_num_tokens_cpu)
 
         for chunk_idx in range(num_chunks):
-            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
+            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
                 slice_input_tensors(chunk_idx))
 
             c_expert_tokens_meta = None
@@ -603,23 +630,26 @@ class FusedMoEModularKernel(torch.nn.Module):
                     expert_tokens_meta, c_topk_ids, local_num_experts,
                     expert_map)
 
-            self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
-                                   a1=a1,
-                                   a1q=c_a1q,
-                                   w1=w1,
-                                   w2=w2,
-                                   topk_ids=c_topk_ids,
-                                   activation=activation,
-                                   global_num_experts=global_num_experts,
-                                   local_num_experts=local_num_experts,
-                                   expert_map=expert_map,
-                                   w1_scale=w1_scale,
-                                   w2_scale=w2_scale,
-                                   w1_zp=w1_zp,
-                                   w2_zp=w2_zp,
-                                   a1q_scale=c_a1q_scale,
-                                   a2_scale=c_a2_scale,
-                                   expert_tokens_meta=c_expert_tokens_meta)
+            self._do_fused_experts(
+                fused_out=slice_output_tensor(chunk_idx),
+                a1=a1,
+                a1q=c_a1q,
+                w1=w1,
+                w2=w2,
+                topk_weights=c_topk_weights,
+                topk_ids=c_topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                local_num_experts=local_num_experts,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                w1_zp=w1_zp,
+                w2_zp=w2_zp,
+                a1q_scale=c_a1q_scale,
+                a2_scale=c_a2_scale,
+                expert_tokens_meta=c_expert_tokens_meta,
+                apply_router_weight_on_input=apply_router_weight_on_input)
 
         return fused_out
@@ -719,6 +749,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             a1q=a1q,
             w1=w1,
             w2=w2,
+            topk_weights=topk_weights,
             topk_ids=topk_ids,
             activation=activation,
             global_num_experts=global_num_experts,
@@ -730,7 +761,8 @@ class FusedMoEModularKernel(torch.nn.Module):
             w2_zp=w2_zp,
             a1q_scale=a1q_scale,
             a2_scale=a2_scale,
-            expert_tokens_meta=expert_tokens_meta)
+            expert_tokens_meta=expert_tokens_meta,
+            apply_router_weight_on_input=apply_router_weight_on_input)
 
         self.prepare_finalize.finalize(
             output, fused_out, topk_weights, topk_ids,
diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
index 9a5315b8b6f7e..fb398eec119fa 100644
--- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
+++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
@@ -48,11 +48,18 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
               fused_expert_output: torch.Tensor,
               topk_weights: torch.Tensor,
               topk_ids: torch.Tensor,
               apply_router_weight_on_input: bool) -> torch.Tensor:
-        # Relax this if an explicit copy is necessary. Note that,
-        # if a copy is employed we have to make sure that the
-        # tensors don't overlap
-        assert output is None
-        return fused_expert_output
+        # Weight application and reduction operations are already done.
+        if output is None:
+            return fused_expert_output
+
+        # MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
+        # tensor.
+        assert output.size() == fused_expert_output.size(), (
+            "output shape is expected to match the fused_expert_output shape. "
+            f"But got output={output.size()}, "
+            f"fused_expert_output={fused_expert_output.size()}")
+        output.copy_(fused_expert_output, non_blocking=True)
+        return output
 
 class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index fefe74cc4ae0b..2f35c19b70541 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -122,6 +122,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -135,6 +136,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
@@ -148,6 +150,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             hidden_states,
             w1,
             w2,
+            topk_weights,
             topk_ids,
             activation,
             global_num_experts,
@@ -161,4 +164,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             workspace13,
             workspace2,
             expert_tokens_meta,
+            apply_router_weight_on_input,
         )
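Taken together, the patch moves top-k weight application and reduction out of `PrepareAndFinalize::finalize()` and into the experts' `apply()` implementations: `apply()` now receives `topk_weights` and `apply_router_weight_on_input`, `finalize_weight_and_reduce_impl()` returns `TopKWeightAndReduceNoOP`, and the expert output shrinks from `(M, topk, K)` to an already-reduced `(M, K)`. The sketch below is a plain-PyTorch reference for the numerical contract the new `apply()` implementations satisfy; the helper name is hypothetical and this is illustrative only, not vLLM code.

```python
import torch


def reference_weight_and_reduce(
        fused_expert_output: torch.Tensor,  # (M, topk, K) per-expert outputs
        topk_weights: torch.Tensor,  # (M, topk) router weights
        apply_router_weight_on_input: bool) -> torch.Tensor:
    # Hypothetical reference helper, not part of vLLM.
    if apply_router_weight_on_input:
        # Router weights were already folded into the activations before the
        # expert MLPs ran, so finishing up is an unweighted sum over topk.
        return fused_expert_output.sum(dim=1)
    # Otherwise scale each expert's output by its routing weight and reduce.
    # This mirrors what TritonExperts now does in-kernel: the second
    # invoke_fused_moe_kernel call multiplies by topk_weights
    # (mul_routed_weight = not apply_router_weight_on_input) and ops.moe_sum
    # collapses the topk dimension into the (M, K) output.
    weights = topk_weights.to(fused_expert_output.dtype).unsqueeze(-1)
    return (fused_expert_output * weights).sum(dim=1)
```

DeepGemmExperts reaches the same result by handing the un-permuted `perm_out` to `TopKWeightAndReduceContiguous`, which is why `TopKWeightAndReduceNoOP` can now simply copy into `output` when `MoEPrepareAndFinalizeNoEP` supplies one.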