diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 4491fcf18106d..0be43da00b533 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1454,10 +1454,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 if use_llama4_routing: routing_method_type = flashinfer.RoutingMethodType.Llama4 + routing_bias = e_score_correction_bias + if routing_bias is not None: + routing_bias = routing_bias.to(torch.bfloat16) out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( routing_logits=router_logits if use_llama4_routing else router_logits.to(torch.float32), - routing_bias=e_score_correction_bias, + routing_bias=routing_bias, hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn).flatten(),