[Performance][Fix] update nvfp4 code to support renorm routing (#28569)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: jiahanc
Date: 2025-11-16 18:02:42 -08:00 (committed by GitHub)
parent 80b6080ddc
commit 561253b37f
2 changed files with 15 additions and 8 deletions


@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
     nvfp4_moe_quant_config,
 )
@@ -1657,16 +1658,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         use_llama4_routing = (
             custom_routing_function is Llama4MoE.custom_routing_function
         )
-        routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+        routing_method_type = layer.routing_method_type
         if use_llama4_routing:
-            routing_method_type = flashinfer.RoutingMethodType.Llama4
+            routing_method_type = RoutingMethodType.Llama4
+        router_logits = (
+            router_logits.to(torch.float32)
+            if routing_method_type == RoutingMethodType.DeepSeekV3
+            else router_logits
+        )
         routing_bias = e_score_correction_bias
         if routing_bias is not None:
             routing_bias = routing_bias.to(torch.bfloat16)
         out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
-            routing_logits=router_logits
-            if use_llama4_routing
-            else router_logits.to(torch.float32),
+            routing_logits=router_logits,
             routing_bias=routing_bias,
             hidden_states=hidden_states_fp4,
             hidden_states_scale=hidden_states_scale_linear_fp4.view(
@@ -1690,8 +1694,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             output2_scale_scalar=layer.g2_alphas.data,
             num_experts=global_num_experts,
             top_k=top_k,
-            n_group=num_expert_group if num_expert_group is not None else 0,
-            topk_group=topk_group if topk_group is not None else 0,
+            n_group=num_expert_group,
+            topk_group=topk_group,
             intermediate_size=layer.intermediate_size_per_partition,
             local_expert_offset=layer.ep_rank * layer.local_num_experts,
             local_num_experts=layer.local_num_experts,
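For orientation, here is a minimal standalone sketch of the logits handling these hunks introduce: the routing method is now taken from layer.routing_method_type instead of being hard-coded to DeepSeekV3, and the fp32 cast is applied only for DeepSeekV3-style routing, so renorm (and Llama4) routing hands logits to the kernel in their original dtype. The RoutingMethodType enum below is a simplified stand-in for the one imported from vllm.model_executor.layers.fused_moe.config, and the shapes and dtypes are illustrative, not taken from the real code path.

# Sketch only: stand-in enum and helper, assuming PyTorch is available.
from enum import Enum, auto

import torch


class RoutingMethodType(Enum):
    DeepSeekV3 = auto()   # sigmoid + group-limited routing, expects fp32 logits
    Renormalize = auto()  # softmax-after-top-k ("renorm") routing
    Llama4 = auto()       # custom Llama4 routing


def prepare_routing_logits(
    router_logits: torch.Tensor,
    routing_method_type: RoutingMethodType,
) -> torch.Tensor:
    # Cast to fp32 only for DeepSeekV3 routing; every other routing method,
    # including renorm, keeps the incoming dtype, mirroring the new
    # router_logits conditional in this diff.
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        return router_logits.to(torch.float32)
    return router_logits


if __name__ == "__main__":
    logits = torch.randn(4, 8, dtype=torch.bfloat16)  # [num_tokens, num_experts]
    out = prepare_routing_logits(logits, RoutingMethodType.Renormalize)
    assert out.dtype == torch.bfloat16  # renorm routing: no cast
    out = prepare_routing_logits(logits, RoutingMethodType.DeepSeekV3)
    assert out.dtype == torch.float32   # DeepSeekV3 routing: fp32 cast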


@@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
 def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
     # TODO(shuw@nvidia): Update when new backends are added.
-    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+    backends_supporting_global_sf = (
+        FlashinferMoeBackend.CUTLASS,
+        FlashinferMoeBackend.TENSORRT_LLM,
+    )
     return backend in backends_supporting_global_sf
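The second hunk widens the set of FlashInfer MoE backends that report support for global scale factors to include TRT-LLM. Below is a self-contained sketch of the updated helper, using a stand-in FlashinferMoeBackend enum limited to the two members visible in this diff; the real enum and helper live in vLLM's FlashInfer utilities.

# Sketch only: illustrates the widened backend set, not the real module.
from enum import Enum


class FlashinferMoeBackend(Enum):
    CUTLASS = "cutlass"
    TENSORRT_LLM = "trtllm"


def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
    # TRT-LLM is now in the supported set alongside CUTLASS.
    backends_supporting_global_sf = (
        FlashinferMoeBackend.CUTLASS,
        FlashinferMoeBackend.TENSORRT_LLM,
    )
    return backend in backends_supporting_global_sf


assert is_flashinfer_supporting_global_sf(FlashinferMoeBackend.TENSORRT_LLM)  # newly True
assert not is_flashinfer_supporting_global_sf(None)  # unchanged: no backend, no support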