diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index cf6325eb85dfd..476521813f464 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
     nvfp4_moe_quant_config,
 )
@@ -1657,16 +1658,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         use_llama4_routing = (
             custom_routing_function is Llama4MoE.custom_routing_function
         )
-        routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+        routing_method_type = layer.routing_method_type
         if use_llama4_routing:
-            routing_method_type = flashinfer.RoutingMethodType.Llama4
+            routing_method_type = RoutingMethodType.Llama4
+        router_logits = (
+            router_logits.to(torch.float32)
+            if routing_method_type == RoutingMethodType.DeepSeekV3
+            else router_logits
+        )
         routing_bias = e_score_correction_bias
         if routing_bias is not None:
             routing_bias = routing_bias.to(torch.bfloat16)
         out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
-            routing_logits=router_logits
-            if use_llama4_routing
-            else router_logits.to(torch.float32),
+            routing_logits=router_logits,
             routing_bias=routing_bias,
             hidden_states=hidden_states_fp4,
             hidden_states_scale=hidden_states_scale_linear_fp4.view(
@@ -1690,8 +1694,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             output2_scale_scalar=layer.g2_alphas.data,
             num_experts=global_num_experts,
             top_k=top_k,
-            n_group=num_expert_group if num_expert_group is not None else 0,
-            topk_group=topk_group if topk_group is not None else 0,
+            n_group=num_expert_group,
+            topk_group=topk_group,
             intermediate_size=layer.intermediate_size_per_partition,
             local_expert_offset=layer.ep_rank * layer.local_num_experts,
             local_num_experts=layer.local_num_experts,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index d9e9b42402712..f22e17945d1f6 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
 
 def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
     # TODO(shuw@nvidia): Update when new backends are added.
-    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+    backends_supporting_global_sf = (
+        FlashinferMoeBackend.CUTLASS,
+        FlashinferMoeBackend.TENSORRT_LLM,
+    )
     return backend in backends_supporting_global_sf
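For context, a minimal standalone sketch of the dtype rule the first hunk introduces: router logits are upcast to float32 only when the layer's routing method is DeepSeekV3, instead of whenever Llama4 routing is not in use. The enum values and helper name below are illustrative only, not vLLM's or FlashInfer's actual API.

```python
# Hedged sketch of the routing-logits dtype selection; names are hypothetical.
from enum import IntEnum

import torch


class RoutingMethodType(IntEnum):
    # Illustrative subset of the routing methods referenced in the diff.
    Default = 0
    DeepSeekV3 = 1
    Llama4 = 2


def prepare_routing_logits(
    router_logits: torch.Tensor,
    routing_method_type: RoutingMethodType,
) -> torch.Tensor:
    """Upcast routing logits to float32 only for DeepSeekV3-style routing."""
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        return router_logits.to(torch.float32)
    return router_logits


# Usage: bfloat16 logits stay bfloat16 for Llama4 routing, become float32 for DeepSeekV3.
logits = torch.randn(4, 8, dtype=torch.bfloat16)
assert prepare_routing_logits(logits, RoutingMethodType.DeepSeekV3).dtype == torch.float32
assert prepare_routing_logits(logits, RoutingMethodType.Llama4).dtype == torch.bfloat16
```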