From 3778673ea81bf5241f40e9c5e90f989bde377acf Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Sun, 14 Dec 2025 23:21:36 -0500
Subject: [PATCH] [Feat] Refactor for `parallel_config` in `FusedMoEModularKernel` (#30282)

Signed-off-by: yewentao256
Signed-off-by: Robert Shaw
Co-authored-by: Robert Shaw
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
---
 .../moe/modular_kernel_tools/common.py        |  3 ++-
 tests/kernels/moe/test_flashinfer.py          | 14 +++++++++++++
 .../layers/fused_moe/cutlass_moe.py           |  2 --
 .../layers/fused_moe/deep_gemm_moe.py         |  2 +-
 .../fused_moe/fused_moe_modular_method.py     |  7 +------
 .../layers/fused_moe/modular_kernel.py        | 21 ++++++++++++-------
 .../compressed_tensors_moe.py                 |  3 ---
 .../quantization/utils/flashinfer_utils.py    |  7 +------
 8 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index d95c22fdf0a5b..6078ce44cee9f 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -594,7 +594,8 @@ def make_modular_kernel(
     )
 
     modular_kernel = mk.FusedMoEModularKernel(
-        prepare_finalize=prepare_finalize, fused_experts=fused_experts
+        prepare_finalize=prepare_finalize,
+        fused_experts=fused_experts,
     )
 
     return modular_kernel
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index d553e2820e5ff..bf4ef2d30466b 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 
 import pytest
 import torch
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
@@ -107,6 +108,19 @@ class TestData:
         layer.w2_input_scale = a2_scale
         layer.w13_weight_scale = w13_weight_scale
         layer.w2_weight_scale = w2_weight_scale
+        # Setup dummy config.
+        layer.moe_parallel_config = mk.FusedMoEParallelConfig(
+            tp_size=1,
+            pcp_size=1,
+            dp_size=1,
+            ep_size=1,
+            tp_rank=1,
+            pcp_rank=1,
+            dp_rank=1,
+            ep_rank=1,
+            use_ep=False,
+            all2all_backend="naive",
+        )
 
         register_moe_scaling_factors(layer)
 
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 552e38a71bf98..4a0b4e82c1b39 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -460,7 +460,6 @@ def cutlass_moe_fp8(
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
-    parallel_config=None,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -538,7 +537,6 @@ def cutlass_moe_fp8(
             c_strides2=c_strides2,
             quant_config=quant_config,
         ),
-        parallel_config=parallel_config,
     )
 
     return fn(
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 4a64736ed767b..5ca91768c9760 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -293,7 +293,7 @@ def deep_gemm_moe_fp8(
     expert_map: torch.Tensor | None = None,
     a1_scale: torch.Tensor | None = None,
     a2_scale: torch.Tensor | None = None,
-    apply_router_weight_on_input=False,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 1947423bf4777..9c9bc2514bb4b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -43,11 +43,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
     ) -> "FusedMoEModularMethod":
-        parallel_config = getattr(
-            getattr(moe_layer, "vllm_config", None),
-            "parallel_config",
-            None,
-        )
         return FusedMoEModularMethod(
             old_quant_method,
             FusedMoEModularKernel(
@@ -55,7 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 getattr(moe_layer, "shared_experts_stream", None),
-                parallel_config=parallel_config,
+                moe_parallel_config=moe_layer.moe_parallel_config,
             ),
         )
 
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 9e75a7c08070e..484314091cb15 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -10,10 +10,12 @@ from typing import final
 
 import torch
 import vllm.envs as envs
-from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEParallelConfig,
+    FusedMoEQuantConfig,
+)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
     count_expert_num_tokens,
@@ -681,7 +683,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         shared_experts_stream: torch.cuda.Stream | None = None,
-        parallel_config: ParallelConfig | None = None,
+        moe_parallel_config: FusedMoEParallelConfig | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
@@ -689,12 +691,15 @@ class FusedMoEModularKernel(torch.nn.Module):
         self.shared_experts = shared_experts
         self.shared_experts_stream = shared_experts_stream
 
-        # cache whether this worker is using DP+EP
-        if parallel_config is None:
-            parallel_config = get_current_vllm_config().parallel_config
+        # prefer an explicit FusedMoEParallelConfig when available (from
+        # FusedMoE layers / tests).
+        # if not provided, assume this kernel is
+        # running in a non-DP+EP context
+        self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
         self.is_dp_ep = (
-            parallel_config.data_parallel_size > 1
-            and parallel_config.enable_expert_parallel
+            moe_parallel_config is not None
+            and moe_parallel_config.dp_size > 1
+            and moe_parallel_config.use_ep
         )
 
         self._post_init_setup()
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 5ad26f9318df3..18c2ab026b2ba 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1266,9 +1266,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 ab_strides2=self.ab_strides2,
                 c_strides1=self.c_strides1,
                 c_strides2=self.ab_strides1_c_strides2,
-                parallel_config=getattr(
-                    getattr(layer, "vllm_config", None), "parallel_config", None
-                ),
             )
 
         else:
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 09d0fe6a2f3ad..3d6e9cda87667 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -247,11 +247,6 @@ def flashinfer_cutlass_moe_fp8(
     assert quant_config is not None
 
     # Construct modular kernel with block-scale support when requested.
-    parallel_config = getattr(
-        getattr(layer, "vllm_config", None),
-        "parallel_config",
-        None,
-    )
     fused_experts = mk.FusedMoEModularKernel(
         build_flashinfer_fp8_cutlass_moe_prepare_finalize(
             moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
@@ -262,7 +257,7 @@ def flashinfer_cutlass_moe_fp8(
             out_dtype=hidden_states.dtype,
            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
-        parallel_config=parallel_config,
+        moe_parallel_config=layer.moe_parallel_config,
    )

    return fused_experts(
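For reference, a minimal usage sketch of the constructor after this refactor. It assumes `prepare_finalize` and `fused_experts` have already been built by whatever backend-specific helpers the caller uses (e.g. via `select_gemm_impl`); those objects and the `build_modular_kernel` wrapper are placeholders, not part of this patch. The single-GPU dummy values mirror the config added to `tests/kernels/moe/test_flashinfer.py` above.

# Usage sketch (not part of the patch): build the MoE-specific parallel
# config explicitly and hand it to FusedMoEModularKernel. `prepare_finalize`
# and `fused_experts` stand in for the backend-specific objects a FusedMoE
# layer already constructs; `build_modular_kernel` is a hypothetical helper.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig


def build_modular_kernel(prepare_finalize, fused_experts):
    # Single-GPU dummy config, mirroring tests/kernels/moe/test_flashinfer.py.
    moe_parallel_config = FusedMoEParallelConfig(
        tp_size=1,
        pcp_size=1,
        dp_size=1,
        ep_size=1,
        tp_rank=1,
        pcp_rank=1,
        dp_rank=1,
        ep_rank=1,
        use_ep=False,
        all2all_backend="naive",
    )
    # The kernel now derives is_dp_ep from dp_size > 1 and use_ep instead of
    # reading the global vLLM ParallelConfig.
    return mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=fused_experts,
        moe_parallel_config=moe_parallel_config,
    )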