From b94faf9d50360217659f5605fce45a562dde6834 Mon Sep 17 00:00:00 2001
From: bnellnm <49004751+bnellnm@users.noreply.github.com>
Date: Tue, 19 Aug 2025 14:00:51 -0400
Subject: [PATCH] [Bugfix] Fix accuracy issue when using flashinfer cutlass
 moe, TP=1 and modelopt. (#23125)

Signed-off-by: Bill Nell
Co-authored-by: Michael Goin
---
 .../fused_moe/flashinfer_cutlass_moe.py        | 49 ++++++++++
 vllm/model_executor/layers/fused_moe/layer.py  |  2 +
 .../compressed_tensors_moe.py                  | 27 ++++++
 .../layers/quantization/modelopt.py            | 90 ++++++++++++-------
 4 files changed, 134 insertions(+), 34 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 3fbe2a0bc69bb..6a9c28b53cd8b 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -7,6 +7,8 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+    FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
@@ -181,3 +183,50 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             ep_rank=self.ep_rank,
             output=output,
         )
+
+
+def flashinfer_cutlass_moe_fp4(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    g1_alphas: torch.Tensor,
+    g2_alphas: torch.Tensor,
+    a1_gscale: torch.Tensor,
+    a2_gscale: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+
+    fused_experts = mk.FusedMoEModularKernel(
+        FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
+                                               a1_gscale=a1_gscale),
+        FlashInferExperts(
+            g1_alphas=g1_alphas,
+            g2_alphas=g2_alphas,
+            a1_gscale=a1_gscale,
+            a2_gscale=a2_gscale,
+            out_dtype=hidden_states.dtype,
+            quant_dtype="nvfp4",
+        ))
+
+    return fused_experts(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=inplace,
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 4924f1fadb3b1..aa8ceda1bb25a 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -198,6 +198,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         else:
             return None
 
+    # Note: init_prepare_finalize should only be called by
+    # prepare_communication_buffer_for_model.
     def init_prepare_finalize(self):
         assert self.moe is not None
         prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 42c43cbc03e57..8ca8249e694ea 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -388,6 +388,33 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
+        elif self.allow_flashinfer:
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+                flashinfer_cutlass_moe_fp4)
+
+            assert is_valid_flashinfer_cutlass_fused_moe(
+                x, layer.w13_weight, layer.w2_weight), (
+                    "Flashinfer CUTLASS Fused MoE not applicable!")
+
+            return flashinfer_cutlass_moe_fp4(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=False,  # TODO(shuw): fix later, now output is high prec
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                w1_scale=layer.w13_blockscale_swizzled,
+                w2_scale=layer.w2_blockscale_swizzled,
+                g1_alphas=layer.g1_alphas,
+                g2_alphas=layer.g2_alphas,
+                a1_gscale=layer.w13_input_scale_quant,
+                a2_gscale=layer.w2_input_scale_quant,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
+
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "CompressedTensorsW4A4MoeMethod.")
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index e0f462b36976f..28f16d1088346 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -966,22 +966,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
                 f" expected one of {allowed_backends}")
 
-        self.fused_experts: Optional[
-            mk.FusedMoEModularKernel] = None  # type: ignore[assignment]
-
     def maybe_make_prepare_finalize(
         self,
         moe: FusedMoEConfig,
     ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
-        if not self.allow_flashinfer:
-            return super().maybe_make_prepare_finalize(moe)
+        if (self.allow_flashinfer and self.flashinfer_moe_backend
+                == FlashinferMoeBackend.CUTLASS):
+            prepare_finalize = (
+                build_flashinfer_fp4_cutlass_moe_prepare_finalize(
+                    moe,
+                    a1_gscale=self.layer.w13_input_scale_quant,
+                ))
+            logger.debug_once("%s", prepare_finalize.__class__.__name__)
+            return prepare_finalize
 
-        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
-            moe,
-            a1_gscale=self.layer.w13_input_scale_quant,
-        )
-        logger.debug_once("%s", prepare_finalize.__class__.__name__)
-        return prepare_finalize
+        return super().maybe_make_prepare_finalize(moe)
 
     def select_gemm_impl(
         self,
@@ -1409,7 +1408,52 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
 
-        if self.fused_experts is None:
+        if self.fused_experts is not None:
+            assert self.allow_flashinfer and \
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+
+            assert is_valid_flashinfer_cutlass_fused_moe(
+                x, layer.w13_weight, layer.w2_weight), (
+                    "Flashinfer CUTLASS Fused MoE not applicable!")
+
+            out = self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=False,  # TODO(shuw): fix later, now output is high prec
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                w1_scale=layer.w13_blockscale_swizzled,
+                w2_scale=layer.w2_blockscale_swizzled,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
+        elif (self.allow_flashinfer
+              and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+                flashinfer_cutlass_moe_fp4)
+
+            out = flashinfer_cutlass_moe_fp4(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                w1_scale=layer.w13_blockscale_swizzled,
+                w2_scale=layer.w2_blockscale_swizzled,
+                g1_alphas=layer.g1_alphas,
+                g2_alphas=layer.g2_alphas,
+                a1_gscale=layer.w13_input_scale_quant,
+                a2_gscale=layer.w2_input_scale_quant,
+                inplace=False,  # TODO(shuw): fix later, now output is high prec
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
+        else:
             # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
             # only (no EP).
             from vllm.model_executor.layers.fused_moe.cutlass_moe import (
@@ -1432,27 +1476,5 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 e=layer.w13_weight.shape[0],
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input)
-        else:
-            assert self.allow_flashinfer and \
-                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-
-            assert is_valid_flashinfer_cutlass_fused_moe(
-                x, layer.w13_weight, layer.w2_weight), (
-                    "Flashinfer CUTLASS Fused MoE not applicable!")
-
-            out = self.fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=False,  # TODO(shuw): fix later, now output is high prec
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-            )
 
         return out
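
For reference, the new module-level entry point flashinfer_cutlass_moe_fp4 can be called directly in the TP=1 case. Below is a minimal sketch (illustrative only, not part of the patch); it assumes a MoE layer whose weights and block scales were already prepared by the ModelOpt NvFp4 fused-MoE method, and reuses the attribute names that appear in the hunks above (`x`, `topk_weights`, `topk_ids`, and `layer` are assumed to come from the surrounding apply() call).

# Illustrative sketch, mirroring the call sites added in this patch.
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    flashinfer_cutlass_moe_fp4)

out = flashinfer_cutlass_moe_fp4(
    hidden_states=x,                          # activations in the model dtype
    w1=layer.w13_weight,                      # nvfp4-packed gate/up weights
    w2=layer.w2_weight,                       # nvfp4-packed down weights
    topk_weights=topk_weights,                # router outputs
    topk_ids=topk_ids,
    w1_scale=layer.w13_blockscale_swizzled,
    w2_scale=layer.w2_blockscale_swizzled,
    g1_alphas=layer.g1_alphas,
    g2_alphas=layer.g2_alphas,
    a1_gscale=layer.w13_input_scale_quant,
    a2_gscale=layer.w2_input_scale_quant,
    inplace=False,
    activation="silu",
    global_num_experts=-1,                    # default: all experts
    expert_map=None,                          # TP=1, no expert parallelism
    apply_router_weight_on_input=False,
)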