From 03ee48111de7372a1231872f26262e7c46ab1c83 Mon Sep 17 00:00:00 2001
From: amirkl94 <203507526+amirkl94@users.noreply.github.com>
Date: Sun, 16 Nov 2025 20:39:44 +0200
Subject: [PATCH] Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)

---
 tests/kernels/moe/test_flashinfer.py     | 18 +++++++---
 .../fused_moe/flashinfer_cutlass_moe.py  | 11 +++++--
 .../layers/quantization/modelopt.py      | 33 +++++++++++--------
 3 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 3a681d4603f8..218df4a2632c 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -77,10 +77,14 @@ class TestData:
 
     @staticmethod
     def make_moe_tensors_8bit(
-        m: int, k: int, n: int, e: int, reorder: bool
+        m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
     ) -> "TestData":
+        is_gated = activation != "relu2_no_mul"
+
         hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
-        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
+        w13 = torch.randn(
+            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
+        )
         w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
 
         # Scale to fp8
@@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
 @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
 def test_flashinfer_cutlass_moe_fp8_no_graph(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
+    activation: str,
     monkeypatch,
 ):
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
-        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
+        td = TestData.make_moe_tensors_8bit(
+            m, k, n, e, reorder=False, activation=activation
+        )
 
         score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
@@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=False,
-            activation="silu",
+            activation=activation,
             global_num_experts=e,
             expert_map=None,
             apply_router_weight_on_input=True,
@@ -253,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
             td.layer,
             topk_weights,
             topk_ids,
-            activation="silu",
+            activation=activation,
             global_num_experts=e,
             expert_map=None,
             apply_router_weight_on_input=True,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 943695f921ad..f864634c6617 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool | None,
     ):
-        assert activation == "silu", (
-            "Only activation silu is supported in FlashInferExperts"
+        from flashinfer.fused_moe.core import ActivationType
+
+        activation_str_to_value_map = {
+            "silu": ActivationType.Swiglu,  # This is the default
+            "relu2_no_mul": ActivationType.Relu2,
+        }
+        assert activation in activation_str_to_value_map, (
+            f"{activation=} missing from {activation_str_to_value_map.keys()=}"
         )
 
         # Select quantization metadata based on FP8 format/path
@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
             output=output,
+            activation_type=activation_str_to_value_map[activation],
             # Informs FlashInfer to use the block-scale decoding path when True
             use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
         )
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index e14753c60c48..cf6325eb85df 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
-        if (
-            envs.VLLM_USE_FLASHINFER_MOE_FP8
-            and has_flashinfer_moe()
-            and self.moe.is_act_and_mul
-        ):
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            if (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+                and not self.moe.is_act_and_mul
+            ):
+                logger.info_once(
+                    "Non-gated MoE is not supported for min-latency mode, "
+                    "falling back to high-throughput mode"
+                )
+                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+
             logger.info_once(
                 f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
             )
@@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             )
 
         if self.flashinfer_moe_backend is not None:
-            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-            register_moe_scaling_factors(layer)
+            if self.moe.is_act_and_mul:
+                layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                 rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
+            register_moe_scaling_factors(layer)
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
@@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
-            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
+            g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
             w2_scale=layer.w2_weight_scale,
-            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
+            g2_alphas=layer.output2_scales_scalar.squeeze(),
             a1_scale=layer.w13_input_scale,
             a1_gscale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            a2_gscale=1.0 / layer.w2_input_scale,
+            a2_gscale=layer.w2_input_scale_inv,
             per_act_token_quant=False,
         )
 
@@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             )
 
         if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
-            assert not renormalize
-            assert activation == "silu", (
-                f"Expected 'silu' activation but got {activation}"
+            assert activation in ("silu", "relu2_no_mul"), (
+                "Expected activation to be in ('silu', 'relu2_no_mul'), "
+                f"but got {activation}"
             )
             return flashinfer_cutlass_moe_fp8(
                 x,
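
Reviewer note on the two activation modes this patch distinguishes: "silu" is the
gated SwiGLU-style activation, where the first GEMM emits 2*n intermediate columns
and the activated gate half is multiplied elementwise with the up half, while
"relu2_no_mul" is non-gated: the first GEMM emits only n columns and the activation
is a squared ReLU with no elementwise product. That is why make_moe_tensors_8bit now
allocates w13 with n rows instead of 2*n in the non-gated case. A minimal PyTorch
sketch of the distinction follows; the helper name and the gate/up ordering are
illustrative assumptions, not vLLM's actual kernel code.

    import torch
    import torch.nn.functional as F

    def moe_intermediate_act(x13: torch.Tensor, activation: str) -> torch.Tensor:
        # x13: output of the first grouped GEMM for one expert's tokens.
        if activation == "silu":
            # Gated (SwiGLU-style): x13 has 2*n columns; split into gate/up
            # halves. The gate/up ordering is an assumption for illustration.
            gate, up = x13.chunk(2, dim=-1)
            return F.silu(gate) * up
        if activation == "relu2_no_mul":
            # Non-gated: x13 has n columns; square the ReLU, no product.
            return torch.relu(x13) ** 2
        raise ValueError(f"unsupported activation: {activation!r}")

    # Shape consequence exercised by the test: for n = 4, the gated path
    # consumes 2*n = 8 intermediate columns, the non-gated path only n = 4.
    assert moe_intermediate_act(torch.randn(2, 8), "silu").shape == (2, 4)
    assert moe_intermediate_act(torch.randn(2, 4), "relu2_no_mul").shape == (2, 4)

The new parametrization can be exercised in isolation with a standard pytest filter,
e.g. pytest tests/kernels/moe/test_flashinfer.py -k relu2_no_mul (assuming the GPU
and FlashInfer requirements of the test are met).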
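A further note on the get_fused_moe_quant_config change: g1_alphas and g2_alphas were
previously computed inline as weight_scale * input_scale products, and are now read
from buffers (output1_scales_gate_scalar, output2_scales_scalar, w2_input_scale_inv)
that register_moe_scaling_factors(layer) is expected to populate, which is why that
call now runs for both backends after the conditional w13 swap. The reason these
alphas are scale products at all: with per-tensor fp8 quantization, a GEMM computed
on quantized operands must be rescaled by input_scale * weight_scale. A minimal
sketch, using plain rounding in place of real fp8 casts; the buffer semantics are
inferred from the names in the diff, not verified against the vLLM source.

    import torch

    torch.manual_seed(0)

    # Per-tensor quantization sketch: q = round(x / scale), so x ~= q * scale.
    x = torch.randn(4, 8)
    w = torch.randn(16, 8)
    input_scale = x.abs().max() / 448.0   # 448 ~ fp8 e4m3 max
    weight_scale = w.abs().max() / 448.0
    x_q = (x / input_scale).round()
    w_q = (w / weight_scale).round()

    # The GEMM on quantized operands is rescaled by
    # alpha = input_scale * weight_scale, which is what
    # output1_scales_gate_scalar presumably precomputes per expert.
    alpha = input_scale * weight_scale
    y = (x_q @ w_q.T) * alpha
    assert torch.allclose(y, x @ w.T, atol=0.5)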