Add TRTLLM MoE NVFP4 kernel to CompressedTensorsW4A4MoeMethod (#28892)

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Mingyuan Ma (2025-11-21 08:54:11 -08:00), committed by GitHub
parent e99e467384
commit b4c8fbaae2
3 changed files with 358 additions and 210 deletions
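
A hedged usage sketch (not part of the commit): the TRT-LLM NVFP4 path added here is only reachable when FlashInfer FP4 MoE is enabled, after which get_flashinfer_moe_backend() chooses between the CUTLASS and TENSORRT_LLM kernels wired up below. The model id is a placeholder for an NVFP4 compressed-tensors MoE checkpoint; VLLM_USE_FLASHINFER_MOE_FP4 is the flag referenced elsewhere in this diff.

import os

# Assumption: enabling FlashInfer FP4 MoE; the CUTLASS vs. TENSORRT_LLM choice
# is resolved by get_flashinfer_moe_backend() when the layer is constructed.
os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="org/nvfp4-moe-checkpoint")  # placeholder checkpoint id
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)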

View File

@@ -8,6 +8,7 @@ from enum import Enum
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -50,9 +51,15 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
expert_weight_is_col_major,
requant_weight_ue8m0_inplace,
@@ -193,6 +200,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4MoeMethod."
)
def create_weights(
self,
@@ -344,21 +358,20 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
return
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)
# w13
w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to(
torch.float32
)
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w13_input_global_scale = (
layer.w13_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
torch.float32
)
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False,
@@ -369,22 +382,92 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
)
# w2
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w2_input_global_scale = (
layer.w2_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w2_input_global_scale = layer.w2_input_global_scale
layer.g2_alphas = torch.nn.Parameter(
((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
torch.float32
),
((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False
(w2_input_global_scale), requires_grad=False
)
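# For the TRT-LLM backend, a single activation global scale (the min over all
# experts) is broadcast to every expert; the other backends keep one scale per
# expert (min over dim 1 for w13, the loaded scale for w2).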
# TensorRT-LLM specific processing
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
# Prepare static weights for TRT-LLM kernel
# alternate: prepare_static_weight_layouts_for_trtllm_moe
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.gemm2_weights_fp4_shuffled = Parameter(
gemm2_weights_fp4_shuffled, requires_grad=False
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.gemm2_scales_fp4_shuffled = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
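# g1_scale_c folds the gemm2 activation quant scale into the gemm1 dequant
# alpha; the TRT-LLM kernel consumes it as output1_scale_scalar, alongside
# g1_alphas (output1_scale_gate_scalar) and g2_alphas (output2_scale_scalar),
# in flashinfer_trtllm_fp4_moe (defined later in this diff).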
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
elif not self.allow_flashinfer:
return super().maybe_make_prepare_finalize(routing_tables)
@@ -411,7 +494,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
if (
self.use_marlin
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
return nvfp4_moe_quant_config(
@@ -452,6 +538,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
)
assert activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=top_k,
global_num_experts=global_num_experts,
num_expert_group=num_expert_group,
topk_group=topk_group,
custom_routing_function=custom_routing_function,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,

View File

@@ -15,7 +15,6 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
@@ -38,6 +37,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
@@ -1136,7 +1137,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.flashinfer_moe_backend = None
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
@@ -1303,138 +1303,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def prepare_static_weights_for_trtllm_fp4_moe(
self,
# args_dequant,
# args,
gemm1_weights,
gemm2_weights,
gemm1_scales_linear_fp4_bytes,
gemm2_scales_linear_fp4_bytes,
hidden_size,
intermediate_size,
num_experts,
):
from flashinfer import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import (
_maybe_get_cached_w3_w1_permute_indices,
get_w2_permute_indices_with_cache,
)
"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
) # fp8 scaling factors
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 2
) # packed fp4
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, hidden_size, intermediate_size // 16
) # fp8 scaling factors
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
# for both w3_w1 and w2 weights and scale factors
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(
gemm1_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave(
gemm1_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
]
.contiguous()
)
)
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(
gemm2_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave(
gemm2_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
]
.contiguous()
)
)
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
)
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
gemm2_scales_fp4_shuffled = (
torch.stack(gemm2_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, hidden_size, intermediate_size // 16)
)
return (
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1 processing
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
if self.allow_flashinfer and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
@@ -1508,7 +1384,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = self.prepare_static_weights_for_trtllm_fp4_moe(
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
@@ -1614,68 +1490,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = (
flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
)
use_llama4_routing = (
custom_routing_function is Llama4MoE.custom_routing_function
)
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = RoutingMethodType.Llama4
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=top_k,
n_group=num_expert_group,
global_num_experts=global_num_experts,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=1.0,
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,
)[0]
return out
custom_routing_function=custom_routing_function,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,

View File

@@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
@@ -110,3 +111,223 @@ def select_nvfp4_gemm_impl(
"CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
"Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
)
def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant,
# args,
gemm1_weights,
gemm2_weights,
gemm1_scales_linear_fp4_bytes,
gemm2_scales_linear_fp4_bytes,
hidden_size,
intermediate_size,
num_experts,
):
from flashinfer import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import (
_maybe_get_cached_w3_w1_permute_indices,
get_w2_permute_indices_with_cache,
)
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
# Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
) # fp8 scaling factors
gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 2
) # packed fp4
gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(num_experts, hidden_size, intermediate_size // 16) # fp8 scaling factors
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
# for both w3_w1 and w2 weights and scale factors
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
_cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(
gemm1_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
_cache_permute_indices,
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave(
gemm1_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
]
.contiguous()
)
)
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(
gemm2_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
.contiguous()
)
permute_sf_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave(
gemm2_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
]
.contiguous()
)
)
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
)
gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
gemm2_scales_fp4_shuffled = (
torch.stack(gemm2_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, hidden_size, intermediate_size // 16)
)
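# Shape summary (E = num_experts, H = hidden_size, I = intermediate_size),
# matching the reshapes above and the call sites in this commit: gemm1 weights
# arrive packed as [E, 2*I, H//2] with block scales [E, 2*I, H//16], gemm2 as
# [E, H, I//2] with scales [E, H, I//16]; the returned tensors keep those
# shapes but are row-permuted and scale-interleaved for the kernel's epilogue
# tiling (epilogue_tile_m = 128).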
return (
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
)
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
custom_routing_function: object | None,
e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
router_logits: Router logits for expert selection
top_k: Number of experts to select per token
global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group
custom_routing_function: Custom routing function (e.g., Llama4)
e_score_correction_bias: Optional routing bias correction
Returns:
Output tensor from the MoE layer
"""
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
# Prepare routing bias
routing_bias = e_score_correction_bias
if routing_bias is not None:
routing_bias = routing_bias.to(torch.bfloat16)
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).flatten(),
gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
torch.float8_e4m3fn
),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,
)[0]
return out
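
For reference, the per-expert dequant factors computed in process_weights_after_loading can be exercised in isolation. Below is a minimal CPU-only sketch with made-up scale values; names follow the diff, and the [num_experts, 2] shape used for w13_input_global_scale (one scale per gate/up shard) is an assumption, not taken from this commit.

import torch

num_experts = 4
# Made-up per-expert global scales standing in for the loaded parameters.
w13_input_global_scale = torch.rand(num_experts, 2) + 0.5  # assumed shape
w2_input_global_scale = torch.rand(num_experts) + 0.5
w13_weight_scale_2 = torch.rand(num_experts) + 0.5
w2_weight_scale_2 = torch.rand(num_experts) + 0.5

# TRT-LLM branch: one shared activation scale, broadcast across experts.
w13_in = w13_input_global_scale.min().to(torch.float32).expand(num_experts)
w2_in = w2_input_global_scale.min().to(torch.float32).expand(num_experts)

g1_alphas = (1 / w13_in) * w13_weight_scale_2
g2_alphas = ((1 / w2_in) * w2_weight_scale_2).to(torch.float32)
w2_input_scale_quant = w2_in
g1_scale_c = (w2_input_scale_quant * g1_alphas).to(torch.float32)

print(g1_alphas.shape, g2_alphas.shape, g1_scale_c.shape)  # all [num_experts]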