Fix routing_bias dtype (#25711)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
commit 081b5594a2
parent 57329a8c01
Author: Shu Wang <shuw@nvidia.com> (committed by GitHub)
Date: 2025-09-25 18:35:14 -05:00

@@ -1454,10 +1454,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
         if use_llama4_routing:
             routing_method_type = flashinfer.RoutingMethodType.Llama4
+        routing_bias = e_score_correction_bias
+        if routing_bias is not None:
+            routing_bias = routing_bias.to(torch.bfloat16)
         out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
             routing_logits=router_logits
             if use_llama4_routing else router_logits.to(torch.float32),
-            routing_bias=e_score_correction_bias,
+            routing_bias=routing_bias,
             hidden_states=hidden_states_fp4,
             hidden_states_scale=hidden_states_scale_linear_fp4.view(
                 torch.float8_e4m3fn).flatten(),
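
For context, the change guards an optional tensor's dtype before the kernel call: e_score_correction_bias may be None, and otherwise may arrive in a dtype other than bfloat16, which trtllm_fp4_block_scale_moe evidently requires (inferred from this commit, not independently confirmed). Below is a minimal standalone sketch of that pattern; the helper name prepare_routing_bias is hypothetical and the kernel call is not reproduced.

    from typing import Optional

    import torch


    def prepare_routing_bias(
            bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Coerce an optional routing bias to bfloat16, passing None through."""
        if bias is None:
            return None
        # Tensor.to() is a no-op when the tensor is already bfloat16.
        return bias.to(torch.bfloat16)


    # Usage: a float32 bias is cast; None is forwarded unchanged.
    assert prepare_routing_bias(torch.zeros(8)).dtype == torch.bfloat16
    assert prepare_routing_bias(None) is None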