diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index fa254030a271a..ad547dd409822 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -8,6 +8,7 @@ from enum import Enum
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -50,9 +51,15 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize,
+    flashinfer_trtllm_fp4_moe,
+    prepare_static_weights_for_trtllm_fp4_moe,
     reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    FlashinferMoeBackend,
+    get_flashinfer_moe_backend,
+)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     expert_weight_is_col_major,
     requant_weight_ue8m0_inplace,
@@ -193,6 +200,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
+        self.flashinfer_moe_backend = None
+        if self.allow_flashinfer:
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
+                " for CompressedTensorsW4A4MoeMethod."
+            )
 
     def create_weights(
         self,
@@ -344,21 +358,20 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
             return
-
-        # swizzle weight scales
-        layer.w13_weight_scale = torch.nn.Parameter(
-            swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
-        )
-
-        layer.w2_weight_scale = torch.nn.Parameter(
-            swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
-        )
-        # w13
-        w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to(
-            torch.float32
-        )
-
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            w13_input_global_scale = (
+                layer.w13_input_global_scale.min()
+                .to(torch.float32)
+                .expand(layer.num_experts)
+            )
+        else:
+            w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
+                torch.float32
+            )
         layer.g1_alphas = torch.nn.Parameter(
             ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
             requires_grad=False,
@@ -369,22 +382,92 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         )
 
         # w2
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            w2_input_global_scale = (
+                layer.w2_input_global_scale.min()
+                .to(torch.float32)
+                .expand(layer.num_experts)
+            )
+        else:
+            w2_input_global_scale = layer.w2_input_global_scale
+
         layer.g2_alphas = torch.nn.Parameter(
-            ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
-                torch.float32
-            ),
+            ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
 
         layer.w2_input_scale_quant = torch.nn.Parameter(
-            (layer.w2_input_global_scale), requires_grad=False
+            (w2_input_global_scale), requires_grad=False
         )
 
+        # TensorRT-LLM specific processing
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = prepare_static_weights_for_trtllm_fp4_moe(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
+
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
+
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
+
+            # Clean up weights that won't be used by TRT-LLM
+            del layer.w2_weight
+            del layer.w2_weight_scale
+            del layer.w13_weight
+            del layer.w13_weight_scale
+        else:
+            # swizzle weight scales
+            layer.w13_weight_scale = torch.nn.Parameter(
+                swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
+            )
+
+            layer.w2_weight_scale = torch.nn.Parameter(
+                swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+            )
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
-        if self.use_marlin:
+        if self.use_marlin or (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
             return None
         elif not self.allow_flashinfer:
             return super().maybe_make_prepare_finalize(routing_tables)
@@ -411,7 +494,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        if self.use_marlin:
+        if (
+            self.use_marlin
+            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
             return None
 
         return nvfp4_moe_quant_config(
@@ -452,6 +538,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         )
 
         assert activation == "silu", "Only SiLU activation is supported."
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            return flashinfer_trtllm_fp4_moe(
+                layer=layer,
+                x=x,
+                router_logits=router_logits,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                custom_routing_function=custom_routing_function,
+                e_score_correction_bias=e_score_correction_bias,
+            )
+
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 6b5ed7762eb31..01a23168bdde3 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -15,7 +15,6 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
-    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
     nvfp4_moe_quant_config,
 )
@@ -38,6 +37,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize,
+    flashinfer_trtllm_fp4_moe,
+    prepare_static_weights_for_trtllm_fp4_moe,
     reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl,
 )
@@ -1136,7 +1137,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.flashinfer_moe_backend = None
-        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
         if self.allow_flashinfer:
             self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
@@ -1303,138 +1303,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def prepare_static_weights_for_trtllm_fp4_moe(
-        self,
-        # args_dequant,
-        # args,
-        gemm1_weights,
-        gemm2_weights,
-        gemm1_scales_linear_fp4_bytes,
-        gemm2_scales_linear_fp4_bytes,
-        hidden_size,
-        intermediate_size,
-        num_experts,
-    ):
-        from flashinfer import nvfp4_block_scale_interleave
-        from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w3_w1_permute_indices,
-            get_w2_permute_indices_with_cache,
-        )
-
-        """Prepare quantized weights for kernel (done offline with weights)."""
-        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
-
-        # Convert quantized weights to proper formats
-        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
-            num_experts, 2 * intermediate_size, hidden_size // 2
-        )  # packed fp4
-        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
-            torch.float8_e4m3fn
-        ).reshape(
-            num_experts, 2 * intermediate_size, hidden_size // 16
-        )  # fp8 scaling factors
-
-        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
-            num_experts, hidden_size, intermediate_size // 2
-        )  # packed fp4
-        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
-            torch.float8_e4m3fn
-        ).reshape(
-            num_experts, hidden_size, intermediate_size // 16
-        )  # fp8 scaling factors
-
-        gemm1_weights_fp4_shuffled = []
-        gemm1_scales_fp4_shuffled = []
-        gemm2_weights_fp4_shuffled = []
-        gemm2_scales_fp4_shuffled = []
-        for i in range(num_experts):
-            # Calculate the permute indices for the following:
-            # 1. Reorder rows of W1 and scales for fused gated activation
-            # 2. Shuffle weights and scaling factors for transposed mma output
-            # for both w3_w1 and w2 weights and scale factors
-            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
-                self._cache_permute_indices,
-                gemm1_weights_fp4[i].view(torch.uint8),
-                epilogue_tile_m,
-            )
-            gemm1_weights_fp4_shuffled.append(
-                gemm1_weights_fp4[i]
-                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
-                .contiguous()
-            )
-
-            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
-                self._cache_permute_indices,
-                gemm1_scales_linear_fp4[i].view(torch.uint8),
-                epilogue_tile_m,
-                num_elts_per_sf=16,
-            )
-            gemm1_scales_fp4_shuffled.append(
-                nvfp4_block_scale_interleave(
-                    gemm1_scales_linear_fp4[i]
-                    .view(torch.uint8)[
-                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
-                    ]
-                    .contiguous()
-                )
-            )
-
-            permute_indices = get_w2_permute_indices_with_cache(
-                self._cache_permute_indices,
-                gemm2_weights_fp4[i].view(torch.uint8),
-                epilogue_tile_m,
-            )
-            gemm2_weights_fp4_shuffled.append(
-                gemm2_weights_fp4[i]
-                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
-                .contiguous()
-            )
-
-            permute_sf_indices = get_w2_permute_indices_with_cache(
-                self._cache_permute_indices,
-                gemm2_scales_linear_fp4[i].view(torch.uint8),
-                epilogue_tile_m,
-                num_elts_per_sf=16,
-            )
-            gemm2_scales_fp4_shuffled.append(
-                nvfp4_block_scale_interleave(
-                    gemm2_scales_linear_fp4[i]
-                    .view(torch.uint8)[
-                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
-                    ]
-                    .contiguous()
-                )
-            )
-
-        # Stack weights for all experts
-        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
-        gemm1_scales_fp4_shuffled = (
-            torch.stack(gemm1_scales_fp4_shuffled)
-            .view(torch.float8_e4m3fn)
-            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-        )
-
-        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
-        gemm2_scales_fp4_shuffled = (
-            torch.stack(gemm2_scales_fp4_shuffled)
-            .view(torch.float8_e4m3fn)
-            .reshape(num_experts, hidden_size, intermediate_size // 16)
-        )
-        return (
-            gemm1_weights_fp4_shuffled,
-            gemm1_scales_fp4_shuffled,
-            gemm2_weights_fp4_shuffled,
-            gemm2_scales_fp4_shuffled,
-        )
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1 processing
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data
-        if (
-            self.allow_flashinfer
-            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+        if self.allow_flashinfer and (
+            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2
             )
@@ -1508,7 +1384,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 gemm1_scales_fp4_shuffled,
                 gemm2_weights_fp4_shuffled,
                 gemm2_scales_fp4_shuffled,
-            ) = self.prepare_static_weights_for_trtllm_fp4_moe(
+            ) = prepare_static_weights_for_trtllm_fp4_moe(
                 layer.w13_weight,
                 layer.w2_weight,
                 layer.w13_weight_scale,
@@ -1614,68 +1490,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
-            import flashinfer
-
-            from vllm.model_executor.models.llama4 import Llama4MoE
-
-            a1_gscale = layer.w13_input_scale_quant
-            (hidden_states_fp4, hidden_states_scale_linear_fp4) = (
-                flashinfer.fp4_quantize(
-                    x,
-                    a1_gscale,
-                    is_sf_swizzled_layout=False,
-                )
-            )
-            use_llama4_routing = (
-                custom_routing_function is Llama4MoE.custom_routing_function
-            )
-            routing_method_type = layer.routing_method_type
-            if use_llama4_routing:
-                routing_method_type = RoutingMethodType.Llama4
-            router_logits = (
-                router_logits.to(torch.float32)
-                if routing_method_type == RoutingMethodType.DeepSeekV3
-                else router_logits
-            )
-            routing_bias = e_score_correction_bias
-            if routing_bias is not None:
-                routing_bias = routing_bias.to(torch.bfloat16)
-            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
-                routing_logits=router_logits,
-                routing_bias=routing_bias,
-                hidden_states=hidden_states_fp4,
-                hidden_states_scale=hidden_states_scale_linear_fp4.view(
-                    torch.float8_e4m3fn
-                ).flatten(),
-                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
-                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
-                    torch.float8_e4m3fn
-                ),
-                gemm1_bias=None,
-                gemm1_alpha=None,
-                gemm1_beta=None,
-                gemm1_clamp_limit=None,
-                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
-                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
-                    torch.float8_e4m3fn
-                ),
-                gemm2_bias=None,
-                output1_scale_scalar=layer.g1_scale_c.data,
-                output1_scale_gate_scalar=layer.g1_alphas.data,
-                output2_scale_scalar=layer.g2_alphas.data,
-                num_experts=global_num_experts,
+            return flashinfer_trtllm_fp4_moe(
+                layer=layer,
+                x=x,
+                router_logits=router_logits,
                 top_k=top_k,
-                n_group=num_expert_group,
+                global_num_experts=global_num_experts,
+                num_expert_group=num_expert_group,
                 topk_group=topk_group,
-                intermediate_size=layer.intermediate_size_per_partition,
-                local_expert_offset=layer.ep_rank * layer.local_num_experts,
-                local_num_experts=layer.local_num_experts,
-                routed_scaling_factor=1.0,
-                tile_tokens_dim=None,
-                routing_method_type=routing_method_type,
-                do_finalize=True,
-            )[0]
-            return out
+                custom_routing_function=custom_routing_function,
+                e_score_correction_bias=e_score_correction_bias,
+            )
 
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 36e8599dd9484..eda40657b1e39 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
 )
 from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
     FlashInferCuteDSLExperts,
@@ -110,3 +111,223 @@ def select_nvfp4_gemm_impl(
         "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
         "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
     )
+
+
+def prepare_static_weights_for_trtllm_fp4_moe(
+    # args_dequant,
+    # args,
+    gemm1_weights,
+    gemm2_weights,
+    gemm1_scales_linear_fp4_bytes,
+    gemm2_scales_linear_fp4_bytes,
+    hidden_size,
+    intermediate_size,
+    num_experts,
+):
+    from flashinfer import nvfp4_block_scale_interleave
+    from flashinfer.fused_moe.core import (
+        _maybe_get_cached_w3_w1_permute_indices,
+        get_w2_permute_indices_with_cache,
+    )
+
+    _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
+    """Prepare quantized weights for kernel (done offline with weights)."""
+    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+    # Convert quantized weights to proper formats
+    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+        num_experts, 2 * intermediate_size, hidden_size // 2
+    )  # packed fp4
+    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+        torch.float8_e4m3fn
+    ).reshape(
+        num_experts, 2 * intermediate_size, hidden_size // 16
+    )  # fp8 scaling factors
+
+    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+        num_experts, hidden_size, intermediate_size // 2
+    )  # packed fp4
+    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+        torch.float8_e4m3fn
+    ).reshape(num_experts, hidden_size, intermediate_size // 16)  # fp8 scaling factors
+
+    gemm1_weights_fp4_shuffled = []
+    gemm1_scales_fp4_shuffled = []
+    gemm2_weights_fp4_shuffled = []
+    gemm2_scales_fp4_shuffled = []
+    for i in range(num_experts):
+        # Calculate the permute indices for the following:
+        # 1. Reorder rows of W1 and scales for fused gated activation
+        # 2. Shuffle weights and scaling factors for transposed mma output
+        # for both w3_w1 and w2 weights and scale factors
+        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+            _cache_permute_indices,
+            gemm1_weights_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+        )
+        gemm1_weights_fp4_shuffled.append(
+            gemm1_weights_fp4[i]
+            .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+            .contiguous()
+        )
+
+        permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+            _cache_permute_indices,
+            gemm1_scales_linear_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+            num_elts_per_sf=16,
+        )
+        gemm1_scales_fp4_shuffled.append(
+            nvfp4_block_scale_interleave(
+                gemm1_scales_linear_fp4[i]
+                .view(torch.uint8)[
+                    permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                ]
+                .contiguous()
+            )
+        )
+
+        permute_indices = get_w2_permute_indices_with_cache(
+            _cache_permute_indices,
+            gemm2_weights_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+        )
+        gemm2_weights_fp4_shuffled.append(
+            gemm2_weights_fp4[i]
+            .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+            .contiguous()
+        )
+
+        permute_sf_indices = get_w2_permute_indices_with_cache(
+            _cache_permute_indices,
+            gemm2_scales_linear_fp4[i].view(torch.uint8),
+            epilogue_tile_m,
+            num_elts_per_sf=16,
+        )
+        gemm2_scales_fp4_shuffled.append(
+            nvfp4_block_scale_interleave(
+                gemm2_scales_linear_fp4[i]
+                .view(torch.uint8)[
+                    permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                ]
+                .contiguous()
+            )
+        )
+
+    # Stack weights for all experts
+    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+    gemm1_scales_fp4_shuffled = (
+        torch.stack(gemm1_scales_fp4_shuffled)
+        .view(torch.float8_e4m3fn)
+        .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+    )
+
+    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+    gemm2_scales_fp4_shuffled = (
+        torch.stack(gemm2_scales_fp4_shuffled)
+        .view(torch.float8_e4m3fn)
+        .reshape(num_experts, hidden_size, intermediate_size // 16)
+    )
+    return (
+        gemm1_weights_fp4_shuffled,
+        gemm1_scales_fp4_shuffled,
+        gemm2_weights_fp4_shuffled,
+        gemm2_scales_fp4_shuffled,
+    )
+
+
+def flashinfer_trtllm_fp4_moe(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    global_num_experts: int,
+    num_expert_group: int | None,
+    topk_group: int | None,
+    custom_routing_function: object | None,
+    e_score_correction_bias: torch.Tensor | None,
+) -> torch.Tensor:
+    """
+    Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
+
+    Args:
+        layer: The MoE layer with weights and scales
+        x: Input tensor
+        router_logits: Router logits for expert selection
+        top_k: Number of experts to select per token
+        global_num_experts: Total number of experts across all ranks
+        num_expert_group: Number of expert groups (for grouped routing)
+        topk_group: Top-k within each group
+        custom_routing_function: Custom routing function (e.g., Llama4)
+        e_score_correction_bias: Optional routing bias correction
+
+    Returns:
+        Output tensor from the MoE layer
+    """
+    import flashinfer
+
+    from vllm.model_executor.models.llama4 import Llama4MoE
+
+    # Quantize input to FP4
+    a1_gscale = layer.w13_input_scale_quant
+    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+        x,
+        a1_gscale,
+        is_sf_swizzled_layout=False,
+    )
+
+    # Determine routing method type
+    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
+    routing_method_type = layer.routing_method_type
+    if use_llama4_routing:
+        routing_method_type = flashinfer.RoutingMethodType.Llama4
+
+    # Prepare routing bias
+    routing_bias = e_score_correction_bias
+    if routing_bias is not None:
+        routing_bias = routing_bias.to(torch.bfloat16)
+
+    router_logits = (
+        router_logits.to(torch.float32)
+        if routing_method_type == RoutingMethodType.DeepSeekV3
+        else router_logits
+    )
+
+    # Call TRT-LLM FP4 block-scale MoE kernel
+    out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+        routing_logits=router_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states_fp4,
+        hidden_states_scale=hidden_states_scale_linear_fp4.view(
+            torch.float8_e4m3fn
+        ).flatten(),
+        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+            torch.float8_e4m3fn
+        ),
+        gemm1_bias=None,
+        gemm1_alpha=None,
+        gemm1_beta=None,
+        gemm1_clamp_limit=None,
+        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+            torch.float8_e4m3fn
+        ),
+        gemm2_bias=None,
+        output1_scale_scalar=layer.g1_scale_c.data,
+        output1_scale_gate_scalar=layer.g1_alphas.data,
+        output2_scale_scalar=layer.g2_alphas.data,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=num_expert_group if num_expert_group is not None else 0,
+        topk_group=topk_group if topk_group is not None else 0,
+        intermediate_size=layer.intermediate_size_per_partition,
+        local_expert_offset=layer.ep_rank * layer.local_num_experts,
+        local_num_experts=layer.local_num_experts,
+        routed_scaling_factor=None,
+        tile_tokens_dim=None,
+        routing_method_type=routing_method_type,
+        do_finalize=True,
+    )[0]
+
+    return out
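Usage note (illustrative sketch, hedged): the two helpers factored out into flashinfer_fp4_moe.py are used in pairs. prepare_static_weights_for_trtllm_fp4_moe() runs once per layer at load time, and flashinfer_trtllm_fp4_moe() runs on every forward call; the latter also reads layer.g1_scale_c, g1_alphas, g2_alphas, w13_input_scale_quant, routing_method_type, intermediate_size_per_partition, ep_rank and local_num_experts, so the layer must already have gone through the quant method's process_weights_after_loading(). The sketch below mirrors the call sites in the diff; the layer attributes and size arguments follow the patch, but the wrapper functions themselves are hypothetical and assume flashinfer plus a CUDA device with NVFP4 support are available.

    import torch

    from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
        flashinfer_trtllm_fp4_moe,
        prepare_static_weights_for_trtllm_fp4_moe,
    )


    def shuffle_trtllm_fp4_weights(layer: torch.nn.Module) -> None:
        # Load-time step: shuffle packed-FP4 weights and their block scales into
        # the layout expected by flashinfer's trtllm_fp4_block_scale_moe kernel.
        w13, s13, w2, s2 = prepare_static_weights_for_trtllm_fp4_moe(
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            layer.w2_weight.size(-2),  # hidden_size
            layer.w13_weight.size(-2) // 2,  # intermediate_size
            layer.w13_weight.size(0),  # num_experts
        )
        layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter(w13, requires_grad=False)
        layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter(s13, requires_grad=False)
        layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter(w2, requires_grad=False)
        layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter(s2, requires_grad=False)


    def trtllm_fp4_moe_forward(
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        global_num_experts: int,
    ) -> torch.Tensor:
        # Forward step: quantize activations to FP4 and run the fused TRT-LLM MoE.
        # Grouped-routing arguments are left unset here; real callers forward them
        # from FusedMoE (num_expert_group, topk_group, custom routing, bias).
        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=top_k,
            global_num_experts=global_num_experts,
            num_expert_group=None,
            topk_group=None,
            custom_routing_function=None,
            e_score_correction_bias=None,
        )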