From b4cef5e6c7bd9ec3dbb951fd913ee36dbadf598d Mon Sep 17 00:00:00 2001
From: amirkl94 <203507526+amirkl94@users.noreply.github.com>
Date: Fri, 15 Aug 2025 09:19:31 +0300
Subject: [PATCH] refactor: Change scaling factors calculation for flashinfer
 FusedMoE (#22812)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Michael Goin
---
 .../layers/fused_moe/fused_moe.py              | 29 +++++-------
 .../model_executor/layers/quantization/fp8.py  |  5 +-
 .../layers/quantization/modelopt.py            |  5 +-
 .../quantization/utils/flashinfer_utils.py     | 46 +++++++++++++++++--
 4 files changed, 60 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 98087a35e15c7..1c497fa5521b9 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
         hidden_states: torch.Tensor,
         input_scale: torch.Tensor,
         gemm1_weights: torch.Tensor,
-        gemm1_weights_scale: torch.Tensor,
-        activation_scale: torch.Tensor,
         gemm2_weights: torch.Tensor,
-        gemm2_weights_scale: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
         num_experts: int,
         top_k: int,
         num_expert_group: Optional[int],
@@ -1206,17 +1206,12 @@
     num_expert_group = num_expert_group if num_expert_group is not None else 0
     topk_group = topk_group if topk_group is not None else 0
 
-    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+    quant_hidden_states, _ = moe_kernel_quantize_input(
         hidden_states,
         input_scale,
         quant_dtype=torch.float8_e4m3fn,
         per_act_token_quant=False)
 
-    output1_scales_scalar = gemm1_weights_scale * input_scale * (
-        1.0 / activation_scale)
-    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
-    output2_scales_scalar = activation_scale * gemm2_weights_scale
-
     from vllm.utils.flashinfer import (
         flashinfer_trtllm_fp8_per_tensor_scale_moe)
     return flashinfer_trtllm_fp8_per_tensor_scale_moe(
@@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
 
 def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
         routing_logits: torch.Tensor,
-        routing_bias: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
         hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
         gemm1_weights: torch.Tensor,
+        gemm2_weights: torch.Tensor,
         output1_scales_scalar: torch.Tensor,
         output1_scales_gate_scalar: torch.Tensor,
-        gemm2_weights: torch.Tensor,
         output2_scales_scalar: torch.Tensor,
         num_experts: int,
         top_k: int,
-        num_expert_group: int,
-        topk_group: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
         intermediate_size: int,
         local_expert_offset: int,
         local_num_experts: int,
-        routed_scaling_factor: float = 1.0,
-        use_routing_scales_on_input: bool = False,
-        tile_tokens_dim: int = 8,
-        routing_method_type: int = 0) -> torch.Tensor:
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
     pass
 
 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5e107c799b9f0..dbd5234286952 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
-    swap_w13_to_w31)
+    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
+    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -694,6 +694,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_weight = layer.w2_weight.data
             w2_weight_scale_inv = layer.w2_weight_scale_inv.data
             if not self.block_quant:
+                register_moe_scaling_factors(layer)
                 rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
             else:
                 w13_weight = layer.w13_weight.data
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 8f9ca73bc505d..22fbbab00e919 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -25,8 +25,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_kernel,
     flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
-    swap_w13_to_w31)
+    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
+    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -430,6 +430,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                               layer.w2_weight)
+            register_moe_scaling_factors(layer)
 
     def apply(
         self,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 9fb194767e4a4..278ee5232f47e 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -82,6 +82,12 @@ def apply_flashinfer_per_tensor_scale_fp8(
     apply_router_weight_on_input: bool,
 ) -> torch.Tensor:
     from flashinfer.fused_moe import RoutingMethodType
+    assert layer.output1_scales_scalar is not None, (
+        "Expected output1_scales_scalar to be initialized")
+    assert layer.output1_scales_gate_scalar is not None, (
+        "Expected output1_scales_gate_scalar to be initialized")
+    assert layer.output2_scales_scalar is not None, (
+        "Expected output2_scales_scalar to be initialized")
 
     from vllm.model_executor.models.llama4 import Llama4MoE
     assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
@@ -92,10 +98,10 @@
         hidden_states=hidden_states,
         input_scale=layer.w13_input_scale,
         gemm1_weights=layer.w13_weight,
-        gemm1_weights_scale=layer.w13_weight_scale,
         gemm2_weights=layer.w2_weight,
-        gemm2_weights_scale=layer.w2_weight_scale,
-        activation_scale=layer.w2_input_scale,
+        output1_scales_scalar=layer.output1_scales_scalar,
+        output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
+        output2_scales_scalar=layer.output2_scales_scalar,
         num_experts=global_num_experts,
         top_k=top_k,
         num_expert_group=num_expert_group,
@@ -105,4 +111,36 @@
         local_num_experts=layer.local_num_experts,
         use_routing_scales_on_input=apply_router_weight_on_input,
         routing_method_type=RoutingMethodType.Llama4,
-    )
\ No newline at end of file
+    )
+
+
+def get_moe_scaling_factors(
+    input_scale: torch.Tensor,
+    gemm1_weights_scale: torch.Tensor,
+    activation_scale: torch.Tensor,
+    gemm2_weights_scale: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    return output1_scales_scalar, output1_scales_gate_scalar, \
+        output2_scales_scalar
+
+
+def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
+    output1_scales, output1_gate_scales, output2_scales = \
+        get_moe_scaling_factors(
+            layer.w13_input_scale, layer.w13_weight_scale,
+            layer.w2_input_scale, layer.w2_weight_scale
+        )
+    layer.register_parameter(
+        'output1_scales_scalar',
+        torch.nn.Parameter(output1_scales, requires_grad=False))
+    layer.register_parameter(
+        'output1_scales_gate_scalar',
+        torch.nn.Parameter(output1_gate_scales, requires_grad=False))
+    layer.register_parameter(
+        'output2_scales_scalar',
+        torch.nn.Parameter(output2_scales, requires_grad=False))
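
Note (illustration only, not part of the patch): register_moe_scaling_factors() folds the static per-tensor quantization scales into the three scalars consumed by the flashinfer TRT-LLM kernel once at weight-loading time, instead of recomputing them inside flashinfer_fused_moe_per_tensor_scale_fp8() on every forward pass. A minimal sketch of the same arithmetic with stand-in scale tensors; the names, shapes and values below are hypothetical placeholders, not taken from a real vLLM layer:

    import torch

    num_experts = 4
    # Hypothetical stand-ins for layer.w13_input_scale, layer.w13_weight_scale,
    # layer.w2_input_scale and layer.w2_weight_scale.
    w13_input_scale = torch.tensor(0.02)        # gemm1 activation input scale
    w13_weight_scale = torch.rand(num_experts)  # gemm1 weight scale(s)
    w2_input_scale = torch.tensor(0.05)         # gemm2 activation input scale
    w2_weight_scale = torch.rand(num_experts)   # gemm2 weight scale(s)

    # Same math as get_moe_scaling_factors(input_scale, gemm1_weights_scale,
    #                                      activation_scale, gemm2_weights_scale)
    output1_scales_scalar = w13_weight_scale * w13_input_scale * (1.0 / w2_input_scale)
    output1_scales_gate_scalar = w13_weight_scale * w13_input_scale
    output2_scales_scalar = w2_input_scale * w2_weight_scale

    print(output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar)

The three results are stored on the layer as non-trainable parameters, so apply_flashinfer_per_tensor_scale_fp8() can pass them straight through to flashinfer_trtllm_fp8_per_tensor_scale_moe without touching the raw weight and activation scales at runtime.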