From f0c98cae2758b1a706537aa412c6868bb060c151 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 10 Jul 2025 17:40:38 -0400 Subject: [PATCH] [Misc] MoE ModularKernel : Introduce TopKWeightAndReduce (#20648) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_pplx_moe.py | 3 + .../layers/fused_moe/batched_deep_gemm_moe.py | 6 + .../batched_triton_or_deep_gemm_moe.py | 19 +++ .../layers/fused_moe/cutlass_moe.py | 6 + .../layers/fused_moe/deep_gemm_moe.py | 6 + .../fused_moe/deepep_ht_prepare_finalize.py | 39 ++--- .../fused_moe/deepep_ll_prepare_finalize.py | 9 +- .../layers/fused_moe/fused_batched_moe.py | 38 ++--- .../layers/fused_moe/fused_moe.py | 6 + .../layers/fused_moe/modular_kernel.py | 44 +++++- .../layers/fused_moe/pplx_prepare_finalize.py | 7 + .../layers/fused_moe/prepare_finalize.py | 15 +- .../fused_moe/topk_weight_and_reduce.py | 139 ++++++++++++++++++ .../layers/fused_moe/triton_deep_gemm_moe.py | 19 +++ 14 files changed, 297 insertions(+), 59 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index d28e0e040629d..f7a661b4bc7b1 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.platforms import current_platform from vllm.utils import round_up @@ -371,6 +373,7 @@ def pplx_prepare_finalize( chunk_topk_weight, chunk_topk_ids, False, + weight_and_reduce_impl=TopKWeightAndReduceDelegate(), ) torch.cuda.synchronize() diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 22de5a026cf04..751ed6abd999a 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,6 +7,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton @@ -217,6 +219,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return False + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. 
+ return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 76adfed9ca1ce..66abd8d7db7bf 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -88,6 +88,25 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): return ((bdge is None or bdge.supports_expert_map()) and (bte is None or bte.supports_expert_map())) + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None + bte_war = bte.finalize_weight_and_reduce_impl() if bte else None + is_bdge_war = bdge_war is not None + is_bte_war = bte_war is not None + + if is_bdge_war and is_bte_war: + assert bdge_war == bte_war, ( + "Both implementations should agree on WeightAndReduce impls. " + f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}") + + if bdge_war is not None: + return bdge_war + + assert bte_war is not None + return bte_war + def workspace_shapes( self, a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c8a8415baf238..623003f65adaa 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -11,6 +11,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, _fp8_quantize, _resize_cache) @@ -255,6 +257,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return not self.use_batched_format + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 40b58f1a4ad91..fdeac43902f96 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -12,6 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, per_token_group_quant_fp8) from vllm.utils import has_deep_gemm, round_up @@ -85,6 +87,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. 
+ return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 8ed42975a32ea..e10927c4dce51 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -6,8 +6,9 @@ import deep_ep import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -187,45 +188,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def _apply_weights_and_reduce(self, num_tokens: int, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - apply_router_weight_on_input: bool, - output_dtype: torch.dtype): - - hidden_dim = fused_expert_output.size(-1) - if fused_expert_output.ndim == 2: - fused_expert_output = fused_expert_output.view( - num_tokens, -1, hidden_dim) - - if not apply_router_weight_on_input: - # The DeepEP combine kernels don't do the topk weight - # multiplication. We multiply the weights locally. - m_x_topk = fused_expert_output.size(0) - fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1)) - - out = torch.empty((num_tokens, hidden_dim), - device=fused_expert_output.device, - dtype=output_dtype) - ops.moe_sum(fused_expert_output, out) - - return out - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> None: + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: assert self.handle is not None # fused_expert_output can have 0 tokens - This happens when none of the # tokens from the all2all reach this EP rank. 
if fused_expert_output.numel() != 0: - fused_expert_output = self._apply_weights_and_reduce( - num_tokens=topk_ids.size(0), + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + fused_expert_output = weight_and_reduce_impl.apply( + output=None, fused_expert_output=fused_expert_output, topk_weights=topk_weights, + topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, - output_dtype=output.dtype) + ) combined_x, _, event = self.buffer.combine( x=fused_expert_output, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 38c33203abfb9..b04f01975849c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -7,6 +7,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) @@ -166,8 +168,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> None: - + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + assert isinstance( + weight_and_reduce_impl, TopKWeightAndReduceDelegate + ), ("Weight application and reduction happens in the combine kernel.") assert self.handle is not None combine_topk_weights = topk_weights diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 591f6b681d3d6..34f8c124759a8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -11,6 +11,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, normalize_scales_shape) @@ -600,25 +602,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - num_tokens = topk_ids.size(0) - num_local_experts = fused_expert_output.size(0) - K = fused_expert_output.size(-1) - assert output.size(0) == num_tokens and output.size(1) == K - - output.fill_(0) - - first_expert = num_local_experts * self.rank - last_expert = first_expert + num_local_experts - - for expert_id in range(first_expert, last_expert): - matching_tokens = topk_ids == expert_id - topks = torch.any(matching_tokens, dim=1).flatten() - rows = torch.count_nonzero(topks) - rhs = fused_expert_output[expert_id - first_expert, :rows, :] - if not 
apply_router_weight_on_input: - rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) - output[topks] = output[topks] + rhs + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -670,6 +664,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return False + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, @@ -877,6 +875,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return False + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 26eeed1cd07f4..1947a3d5fac11 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -25,6 +25,8 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( @@ -1596,6 +1598,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 8453ab0dc951a..d0d8c7d6f41e9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -23,7 +23,7 @@ from vllm.utils import cdiv # # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] # -# Each component will be independent of the others except for +# Each component will be independent of (but may inform) the others except for # [Quantize-Dispatch] and `[Combine] (see below). The components can then be # mixed and matched with so that DP+EP can be supported easily for multiple # MoE kernel implementations. @@ -32,13 +32,19 @@ from vllm.utils import cdiv # * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE # inputs (e.g. quantization, distribution) and finalization of Moe outputs. # The prepare method must take care of any needed quantization and the -# finalize method must apply weights and do the final reduction of the output. 
+# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method, +# may apply weights and/or do the final reduction of the output. # * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused -# MoE operation. One important feature to note is that this class does not -# apply topk weights or reduce the final output. +# MoE operation, i.e matmul + act_mul + optionally quant + matmul. +# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do +# the weight application and/or reduction. The class communicates this +# to [Finalize] via a TopKWeightAndReduce object. # * FusedMoEModularKernel - an interface class that combines a # FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to # provide the standard fused MoE kernel interface. +# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen +# by the FusedMoEPermuteExpertsUnpermute implementation that is passed +# on to [Finalize]. # # [Quantize-Prepare] and [Finalize] functionality are bundled into a single # class `FusedMoEPrepareAndFinalize` since they could use collective @@ -117,6 +123,24 @@ class ExpertTokensMetadata: expert_num_tokens_cpu=expert_num_tokens_cpu) +class TopKWeightAndReduce(ABC): + """ + An abstract base class for weight application and reduction implementations. + """ + + @abstractmethod + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + """ + Apply topk_weights to the fused_experts_outputs and/or reduce. + If an output tensor is not passed, it will be created in the + function. + """ + raise NotImplementedError + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -173,6 +197,7 @@ class FusedMoEPrepareAndFinalize(ABC): topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, ) -> None: """ Perform any combine plus apply weights and perform a reduction on the @@ -184,6 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC): - topk_ids: The topk_ids. - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. + - weight_and_reduce_impl: An optional TopKWeightAndReduce + implementation. 
""" raise NotImplementedError @@ -323,6 +350,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ self.supports_chunking() + def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: + raise NotImplementedError + @abstractmethod def apply( self, @@ -702,7 +732,9 @@ class FusedMoEModularKernel(torch.nn.Module): a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta) - self.prepare_finalize.finalize(output, fused_out, topk_weights, - topk_ids, apply_router_weight_on_input) + self.prepare_finalize.finalize( + output, fused_out, topk_weights, topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl()) return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 1ce47e3eeca3c..46f1231a617b8 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -8,6 +8,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( _validate_scale_shape, moe_kernel_quantize_input) from vllm.utils import cdiv, round_up @@ -222,7 +224,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: + assert isinstance( + weight_and_reduce_impl, TopKWeightAndReduceDelegate + ), ("Weight application and reduction happens in the combine kernel.") + # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index d413d2ce0e23c..567a0a88fec0a 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -6,8 +6,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -62,6 +62,13 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights, apply_router_weight_on_input) + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input) diff --git 
a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py new file mode 100644 index 0000000000000..9a5315b8b6f7e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + + +class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): + """ + Useful in the case when some FusedMoEPermuteExpertsUnpermute + implementation does not perform weight application and reduction + but cannot address the needs of all the compatible PrepareAndFinalize + implementations. + For example, BatchedTritonExperts is compatible with both + PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize + does the weight-application + reduction as part of the pplx combine kernel. + But the BatchedPrepareAndFinalize needs an implementation. To facilitate + this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate + so the PrepareAndFinalize implementations could choose how to + weight + reduce. + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceDelegate) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + raise RuntimeError("The caller is expected to choose an appropriate " + "TopKWeightAndReduce implementation.") + + +class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): + """ + The fused_experts outputs have already been weight applied and reduced. + This implementation is a no-op. + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceNoOP) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + # Relax this if an explicit copy is necessary. Note that, + # if a copy is employed we have to make sure that the + # tensors don't overlap + assert output is None + return fused_expert_output + + +class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): + """ + TopKWeightAndReduce implementation for a fused_experts output + of shape (m, topk, K) + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceContiguous) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + + m, num_topk = topk_ids.size() + k = fused_expert_output.size(-1) + if fused_expert_output.ndim == 2: + fused_expert_output = fused_expert_output.view(m, num_topk, k) + + assert fused_expert_output.size() == (m, num_topk, k), ( + f"Expected fused_expert_output size {(m, num_topk, k)}. But got " + f"{fused_expert_output.size()}") + + if not apply_router_weight_on_input: + fused_expert_output.mul_(topk_weights.view(m, -1, 1)) + + if output is None: + output = torch.empty((m, k), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype) + assert output.size() == (m, k), ( + f"Expected output size {(m, k)}. 
But got {output.size()}") + + ops.moe_sum(fused_expert_output, output) + return output + + +class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce): + """ + TopKWeightAndReduce implementation for a fused_experts output + of shape (num_experts, batch_size, K) + """ + + def __init__(self, rank: int): + self.rank = rank + + def __eq__(self, other): + return (isinstance(other, TopKWeightAndReduceNaiveBatched) + and (other.rank == self.rank)) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + assert fused_expert_output.ndim == 3 + num_tokens = topk_ids.size(0) + num_local_experts = fused_expert_output.size(0) + K = fused_expert_output.size(-1) + + if output is None: + output = torch.zeros((num_tokens, K), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype) + else: + output.fill_(0) + + assert output.size() == (num_tokens, K), ( + f"Expected output size {(num_tokens, K)}, but got {output.size()}") + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + matching_tokens = topk_ids == expert_id + topks = torch.any(matching_tokens, dim=1).flatten() + rows = torch.count_nonzero(topks) + rhs = fused_expert_output[expert_id - first_expert, :rows, :] + if not apply_router_weight_on_input: + rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) + output[topks] = output[topks] + rhs + + return output diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 2db7626eba84b..891ffd1c79b45 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -69,6 +69,25 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): return ((dge is None or dge.supports_expert_map()) and (te is None or te.supports_expert_map())) + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + dge = self.deep_gemm_expert + te = self.triton_expert + dge_war = dge.finalize_weight_and_reduce_impl() if dge else None + te_war = te.finalize_weight_and_reduce_impl() if te else None + is_dge_war = dge_war is not None + is_te_war = te_war is not None + + if is_dge_war and is_te_war: + assert dge_war == te_war, ( + "Both implementations should agree on WeightAndReduce impls. " + f"Got dge_war: {dge_war}, and te_war: {te_war}") + + if dge_war is not None: + return dge_war + + assert te_war is not None + return te_war + def workspace_shapes( self, a: torch.Tensor,
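A standalone sketch of the finalize() hand-off described in the modular_kernel.py comments above, assuming only torch. The names Delegate, ContiguousReduce, and the simplified finalize() below are illustrative stand-ins for TopKWeightAndReduceDelegate, TopKWeightAndReduceContiguous, and MoEPrepareAndFinalizeNoEP.finalize(); they are not the patch's own code, just a minimal picture of the contract: the experts implementation advertises a weight-and-reduce choice, and finalize() either honors it or, on a Delegate, substitutes its own.

from typing import Optional

import torch


class Delegate:
    """Stand-in for TopKWeightAndReduceDelegate: the experts implementation
    has no preference, so finalize() must pick a concrete implementation."""


class ContiguousReduce:
    """Stand-in for TopKWeightAndReduceContiguous: weight and reduce a
    (m, topk, K) expert output down to (m, K)."""

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor,
              topk_weights: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        m, num_topk, k = fused_expert_output.shape
        if not apply_router_weight_on_input:
            # Router weights were not folded into the inputs, so apply them
            # here before reducing over the topk dimension.
            fused_expert_output = fused_expert_output * topk_weights.view(
                m, num_topk, 1)
        reduced = fused_expert_output.sum(dim=1)
        if output is None:
            return reduced
        output.copy_(reduced)
        return output


def finalize(output: torch.Tensor, fused_expert_output: torch.Tensor,
             topk_weights: torch.Tensor, weight_and_reduce_impl) -> None:
    # Mirrors the pattern used by the PrepareAndFinalize classes in this
    # patch: a Delegate means "caller's choice", so substitute a concrete
    # implementation before applying weights and reducing.
    if isinstance(weight_and_reduce_impl, Delegate):
        weight_and_reduce_impl = ContiguousReduce()
    weight_and_reduce_impl.apply(output=output,
                                 fused_expert_output=fused_expert_output,
                                 topk_weights=topk_weights,
                                 apply_router_weight_on_input=False)


if __name__ == "__main__":
    m, topk, k = 4, 2, 8
    out = torch.empty(m, k)
    finalize(out, torch.randn(m, topk, k), torch.rand(m, topk), Delegate())
    assert out.shape == (m, k)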