From 081b5594a2b1a37ea793659bb6767c497beef45d Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Thu, 25 Sep 2025 18:35:14 -0500
Subject: [PATCH] Fix routing_bias dtype (#25711)

Signed-off-by: Shu Wang.
---
 vllm/model_executor/layers/quantization/modelopt.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 4491fcf18106d..0be43da00b533 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1454,10 +1454,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
         if use_llama4_routing:
             routing_method_type = flashinfer.RoutingMethodType.Llama4
+        routing_bias = e_score_correction_bias
+        if routing_bias is not None:
+            routing_bias = routing_bias.to(torch.bfloat16)
         out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
             routing_logits=router_logits if use_llama4_routing
             else router_logits.to(torch.float32),
-            routing_bias=e_score_correction_bias,
+            routing_bias=routing_bias,
             hidden_states=hidden_states_fp4,
             hidden_states_scale=hidden_states_scale_linear_fp4.view(
                 torch.float8_e4m3fn).flatten(),
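
For context, a minimal standalone sketch of the dtype fix follows. The helper name normalize_routing_bias is hypothetical (it is not part of the patch); the premise, inferred from the change above, is that flashinfer.fused_moe.trtllm_fp4_block_scale_moe wants its routing bias in bfloat16, while e_score_correction_bias may be loaded from a checkpoint as float32. The None check mirrors the patch, since not every routing method supplies a bias.

    import torch

    def normalize_routing_bias(routing_bias):
        # Hypothetical helper mirroring the patch: cast an optional
        # per-expert routing bias to bfloat16 before handing it to a
        # kernel that assumes that dtype. None passes through unchanged.
        if routing_bias is None:
            return None
        return routing_bias.to(torch.bfloat16)

    # Usage: a float32 correction bias, one entry per expert.
    bias = torch.randn(8, dtype=torch.float32)
    print(normalize_routing_bias(bias).dtype)  # torch.bfloat16
    print(normalize_routing_bias(None))        # None

Doing the cast once, before the kernel call, keeps the call site itself unchanged apart from the argument name and avoids repeating the dtype check for each routing method.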