From 39d28108f46ac320bccdf76e84640dbf2bd02bb4 Mon Sep 17 00:00:00 2001
From: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Date: Sun, 30 Nov 2025 18:02:40 +0200
Subject: [PATCH] [Feat] Support non-gated activations in NVFP4 modelopt path
 (#29004)

---
 tests/kernels/moe/test_flashinfer_moe.py      | 24 +++++--
 tests/kernels/moe/utils.py                    | 11 +++-
 tests/kernels/utils.py                        |  8 ++-
 vllm/model_executor/layers/fused_moe/layer.py | 12 +++-
 .../layers/quantization/modelopt.py           | 65 ++++++++++++++++---
 5 files changed, 98 insertions(+), 22 deletions(-)

diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py
index be3e36865d1a4..b2be03ecee2f1 100644
--- a/tests/kernels/moe/test_flashinfer_moe.py
+++ b/tests/kernels/moe/test_flashinfer_moe.py
@@ -16,11 +16,11 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     FlashInferExperts,
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
+    create_flashinfer_prepare_finalize,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
@@ -48,9 +48,10 @@ MNK_FACTORS = [
 @pytest.mark.parametrize("e", [40, 64, 256])
 @pytest.mark.parametrize("topk", [1, 6, 8])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
 @torch.inference_mode()
 def test_flashinfer_fp4_moe_no_graph(
-    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
 ):
     current_platform.seed_everything(7)
     with set_current_vllm_config(
@@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
         a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
 
         quant_blocksize = 16
+        is_gated_act = activation == "silu_and_mul"
 
         w1_q, w2_q, quant_config = make_test_quant_config(
             e,
@@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
             quant_dtype="nvfp4",
             block_shape=None,
             per_act_token_quant=False,
+            make_gate=is_gated_act,
         )
 
         score = torch.randn((m, e), device="cuda", dtype=dtype)
@@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
         assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
 
         flashinfer_experts = FusedMoEModularKernel(
-            MoEPrepareAndFinalizeNoEP(),
+            create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
             FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
         )
 
+        fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
+
         flashinfer_output = flashinfer_experts(
             hidden_states=a,
             w1=w1_q,
             w2=w2_q,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            activation=fi_activation,
         )
 
         # Reference check:
@@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
             block_size=quant_blocksize,
         )
 
-        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
+        w1_d = torch.empty(
+            (e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
+        )
         w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
 
         for idx in range(0, e):
@@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
                 block_size=quant_blocksize,
             )
 
-        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
+        torch_output = torch_moe(
+            a_in_dtype, w1_d, w2_d, score, topk, activation=activation
+        )
 
         torch.testing.assert_close(
             torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py
index c7e6c4240e853..f0c8c8033b8eb 100644
--- a/tests/kernels/moe/utils.py
+++ b/tests/kernels/moe/utils.py
@@ -264,13 +264,20 @@ def make_test_weights(
     quant_dtype: torch.dtype | str | None = None,
     block_shape: list[int] | None = None,
     per_out_ch_quant: bool = False,
+    make_gate: bool = True,
 ) -> tuple[
     tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
     tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
 ]:
     return (
         make_test_weight(
-            e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
+            e,
+            (2 if make_gate else 1) * n,
+            k,
+            in_dtype,
+            quant_dtype,
+            block_shape,
+            per_out_ch_quant,
         ),
         make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
     )
@@ -297,6 +304,7 @@ def make_test_quant_config(
     quant_dtype: torch.dtype | str | None = None,
     per_act_token_quant: bool = False,
     block_shape: list[int] | None = None,
+    make_gate: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
     (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
         e,
@@ -306,6 +314,7 @@ def make_test_quant_config(
         quant_dtype,
         per_out_ch_quant=per_act_token_quant,
         block_shape=block_shape,
+        make_gate=make_gate,
     )
 
     # Hacky/trivial scales for nvfp4.
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 98646442391fe..72c79370d19c1 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -14,6 +14,7 @@ from torch._prims_common import TensorLikeType
 
 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention.backends.abstract import AttentionType
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.torch_utils import make_tensor_with_pad
@@ -839,6 +840,7 @@ def torch_experts(
     per_act_token_quant=False,
     block_shape: list[int] | None = None,
     apply_router_weights_on_input: bool = False,
+    activation: str = "silu_and_mul",
 ) -> torch.Tensor:
     assert (
         global_num_experts == -1
@@ -881,6 +883,8 @@ def torch_experts(
 
     f32 = torch.float32
 
+    act = CustomOp.op_registry[activation]
+
     for i in range(num_experts):
         mask = topk_ids == i
         if mask.sum():
@@ -888,7 +892,7 @@ def torch_experts(
             tmp1 = a[mask] @ w1[i].transpose(0, 1)
             if b_bias1 is not None:
                 tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
-            tmp2 = SiluAndMul()(tmp1)
+            tmp2 = act()(tmp1)
             out[mask] = tmp2 @ w2[i].transpose(0, 1)
             if b_bias2 is not None:
                 out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
@@ -969,6 +973,7 @@ def torch_moe(
     b_bias2: torch.Tensor | None = None,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    activation: str = "silu_and_mul",
 ) -> torch.Tensor:
     score = torch.softmax(score, dim=-1, dtype=torch.float32)
     topk_weight, topk_ids = torch.topk(score, topk)
@@ -982,6 +987,7 @@ def torch_moe(
         b_bias1,
         b_bias2,
         expert_map,
+        activation=activation,
     )
 
 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index e180b4f4ba23f..902a77987d61a 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -600,14 +600,20 @@ class FusedMoE(CustomOp):
             # Avoid circular import
             from vllm.model_executor.layers.quantization.modelopt import (
                 ModelOptFp8MoEMethod,
+                ModelOptNvFp4FusedMoE,
             )
 
             if not isinstance(
-                self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
+                self.quant_method,
+                (
+                    UnquantizedFusedMoEMethod,
+                    ModelOptFp8MoEMethod,
+                    ModelOptNvFp4FusedMoE,
+                ),
             ):
                 raise NotImplementedError(
                     "is_act_and_mul=False is supported only for unquantized "
-                    "and ModelOpt FP8 moe for now"
+                    ", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
                 )
             if not current_platform.is_cuda():
                 raise NotImplementedError(
@@ -1277,7 +1283,7 @@ class FusedMoE(CustomOp):
             self._load_combined_w13_weight_scale(
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
-                param=param,
+                param=expert_data,
                 tp_rank=self.tp_rank,
             )
             return True if return_success else None
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 709c86175477a..034e97a713cdd 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1216,7 +1216,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         w13_weight = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
                 dtype=weight_dtype,
@@ -1245,7 +1245,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
                 dtype=weight_scale_dtype,
@@ -1275,7 +1275,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         )
 
         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@@ -1296,7 +1298,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         global_scale_num_experts = global_num_experts if use_global_sf else num_experts
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                global_scale_num_experts,
+                2 if self.moe.is_act_and_mul else 1,
+                dtype=torch.float32,
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
@@ -1312,9 +1318,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data
 
-        if self.allow_flashinfer and (
-            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        if (
+            self.allow_flashinfer
+            and (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+                or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            )
+            and self.moe.is_act_and_mul
         ):
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2
@@ -1324,7 +1334,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
 
         # Common processing for w13_weight_scale_2
-        if not torch.allclose(
+        if self.moe.is_act_and_mul and not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
             logger.warning_once(
@@ -1437,11 +1447,39 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             w13_blockscale_swizzled, requires_grad=False
         )
 
+        w13_weight = layer.w13_weight
+        intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
+        if intermediate_size_pad:
+            # padding gated activations will require to split w1 and w3
+            # and pad them individually
+            assert not self.moe.is_act_and_mul, (
+                "The intermediate size required padding, "
+                "but padding is not implemented for gated activations"
+            )
+
+            layer.w13_weight = Parameter(
+                torch.nn.functional.pad(
+                    w13_weight, (0, 0, 0, intermediate_size_pad)
+                ),
+                requires_grad=False,
+            )
+            layer.w2_weight = Parameter(
+                torch.nn.functional.pad(
+                    layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
+                ),
+                requires_grad=False,
+            )
+            layer.w2_weight_scale = Parameter(
+                torch.nn.functional.pad(
+                    layer.w2_weight_scale, (0, intermediate_size_pad // 16)
+                ),
+                requires_grad=False,
+            )
+
         w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
         layer.w2_weight_scale = Parameter(
             w2_blockscale_swizzled, requires_grad=False
         )
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
@@ -1484,7 +1522,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        assert activation == "silu", "Only SiLU activation is supported."
+        if not self.moe.is_act_and_mul:
+            assert (
+                self.allow_flashinfer
+                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            ), (
+                "Non-gated activations are only supported by the"
+                " flashinfer CUTLASS backend for modelopt checkpoints"
+            )
 
         if (
             self.allow_flashinfer
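
Note (not part of the patch): a minimal, self-contained sketch of the reference math that the updated torch_experts helper above checks against, assuming "relu2" denotes squared ReLU (relu(x)**2, matching the flashinfer "relu2_no_mul" name used in the test) and using the same weight layout as the test: w13 has 2 * n rows on the gated silu_and_mul path and n rows on the non-gated path. The function name ref_expert_mlp and the toy shapes are illustrative only; this is plain PyTorch and does not call any vLLM kernels.

import torch
import torch.nn.functional as F


def ref_expert_mlp(
    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, gated: bool
) -> torch.Tensor:
    # x: (tokens, k), w2: (k, n); w13 is (2 * n, k) when gated, (n, k) otherwise.
    h = x @ w13.t()
    if gated:
        # silu_and_mul: first half is the gate, second half is the up projection.
        gate, up = h.chunk(2, dim=-1)
        act = F.silu(gate) * up
    else:
        # relu2, assumed here to be squared ReLU.
        act = F.relu(h) ** 2
    return act @ w2.t()


tokens, k, n = 4, 32, 8
x = torch.randn(tokens, k)
w2 = torch.randn(k, n)
out_gated = ref_expert_mlp(x, torch.randn(2 * n, k), w2, gated=True)
out_relu2 = ref_expert_mlp(x, torch.randn(n, k), w2, gated=False)
assert out_gated.shape == out_relu2.shape == (tokens, k)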
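Note (not part of the patch): a small shape-arithmetic sketch of the non-gated padding added to process_weights_after_loading above, assuming fp4 weights pack two values per byte and a block-scale group size of 16, as in the hunk. The sizes E, inter, hidden, and pad are hypothetical; pad stands in for the gap between the swizzled block-scale rows and the original intermediate size.

import torch
import torch.nn.functional as F

E, inter, hidden, pad = 2, 112, 256, 16

w13 = torch.zeros(E, inter, hidden // 2)        # non-gated w13: 2 fp4 values per byte
w2 = torch.zeros(E, hidden, inter // 2)         # w2 is fp4-packed along intermediate
w2_scale = torch.zeros(E, hidden, inter // 16)  # one block scale per 16 elements

w13 = F.pad(w13, (0, 0, 0, pad))            # pad the intermediate dim (dim -2)
w2 = F.pad(w2, (0, pad // 2, 0, 0))         # packed fp4: pad // 2 extra bytes
w2_scale = F.pad(w2_scale, (0, pad // 16))  # pad // 16 extra group scales

assert w13.shape == (E, inter + pad, hidden // 2)
assert w2.shape == (E, hidden, (inter + pad) // 2)
assert w2_scale.shape == (E, hidden, (inter + pad) // 16)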