Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-26 17:29:44 +08:00)
[Performance][Fix] update nvfp4 code to support renorm routing (#28569)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 80b6080ddc
commit 561253b37f
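For context: "renorm routing" presumably refers to the Renormalize routing method, where the router picks the top-k experts and then normalizes only the selected weights, as opposed to the grouped DeepSeekV3-style routing the code previously hard-coded. Below is a minimal, illustrative sketch of one common renormalize scheme (top-k over the logits, then softmax over the selected values); the exact ordering inside the kernel may differ, and this helper is not part of the diff.

import torch

def renormalize_routing(router_logits: torch.Tensor, top_k: int):
    # Select the top-k experts per token from the raw logits, then softmax
    # only over the selected logits so the k routing weights sum to 1.
    topk_vals, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    topk_weights = torch.softmax(topk_vals.float(), dim=-1)
    return topk_weights, topk_ids

logits = torch.randn(4, 128)            # 4 tokens routed over 128 experts
weights, ids = renormalize_routing(logits, top_k=8)
print(weights.sum(dim=-1))              # each row sums to ~1.0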
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
     nvfp4_moe_quant_config,
 )
@@ -1657,16 +1658,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             use_llama4_routing = (
                 custom_routing_function is Llama4MoE.custom_routing_function
             )
-            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+            routing_method_type = layer.routing_method_type
             if use_llama4_routing:
-                routing_method_type = flashinfer.RoutingMethodType.Llama4
+                routing_method_type = RoutingMethodType.Llama4
+            router_logits = (
+                router_logits.to(torch.float32)
+                if routing_method_type == RoutingMethodType.DeepSeekV3
+                else router_logits
+            )
             routing_bias = e_score_correction_bias
             if routing_bias is not None:
                 routing_bias = routing_bias.to(torch.bfloat16)
             out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
-                routing_logits=router_logits
-                if use_llama4_routing
-                else router_logits.to(torch.float32),
+                routing_logits=router_logits,
                 routing_bias=routing_bias,
                 hidden_states=hidden_states_fp4,
                 hidden_states_scale=hidden_states_scale_linear_fp4.view(
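The hunk above stops hard-coding DeepSeekV3 routing, takes the routing method from the layer instead, and upcasts the router logits to float32 only when DeepSeekV3 routing is in use. A standalone restatement of that dtype decision, using a hypothetical helper name that does not appear in the diff:

import torch

def maybe_upcast_logits(router_logits: torch.Tensor, is_deepseek_v3: bool) -> torch.Tensor:
    # Mirrors the change above: only DeepSeekV3 routing upcasts the logits to
    # float32; other routing methods keep the model's dtype (e.g. bfloat16).
    return router_logits.to(torch.float32) if is_deepseek_v3 else router_logits

logits = torch.randn(2, 64, dtype=torch.bfloat16)
print(maybe_upcast_logits(logits, is_deepseek_v3=True).dtype)   # torch.float32
print(maybe_upcast_logits(logits, is_deepseek_v3=False).dtype)  # torch.bfloat16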
@@ -1690,8 +1694,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 output2_scale_scalar=layer.g2_alphas.data,
                 num_experts=global_num_experts,
                 top_k=top_k,
-                n_group=num_expert_group if num_expert_group is not None else 0,
-                topk_group=topk_group if topk_group is not None else 0,
+                n_group=num_expert_group,
+                topk_group=topk_group,
                 intermediate_size=layer.intermediate_size_per_partition,
                 local_expert_offset=layer.ep_rank * layer.local_num_experts,
                 local_num_experts=layer.local_num_experts,
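The call-site change above forwards num_expert_group and topk_group unchanged instead of coercing None to 0; with renormalize routing there is no expert grouping, so these arguments can legitimately stay None. An illustrative check of that pass-through behavior (the values below are made up, not taken from the commit):

def group_kwargs(num_expert_group, topk_group):
    # Pass-through, mirroring the change above: None is forwarded as None
    # rather than being replaced with 0 before the kernel call.
    return {"n_group": num_expert_group, "topk_group": topk_group}

print(group_kwargs(8, 4))        # grouped routing, e.g. DeepSeekV3-style
print(group_kwargs(None, None))  # renormalize routing: no expert grouping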
@@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
 
 def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
     # TODO(shuw@nvidia): Update when new backends are added.
-    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+    backends_supporting_global_sf = (
+        FlashinferMoeBackend.CUTLASS,
+        FlashinferMoeBackend.TENSORRT_LLM,
+    )
     return backend in backends_supporting_global_sf
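With this change the helper reports that both the CUTLASS and TENSORRT_LLM FlashInfer MoE backends support global scale factors. A small usage sketch; the import path is an assumption, since the diff does not show which file these helpers live in:

# Sketch only: the module path below is assumed, not shown in this diff.
from vllm.utils.flashinfer import (
    get_flashinfer_moe_backend,
    is_flashinfer_supporting_global_sf,
)

backend = get_flashinfer_moe_backend()
if is_flashinfer_supporting_global_sf(backend):
    print(f"{backend}: global scale factors supported")
else:
    print(f"{backend}: global scale factors not supported by this backend")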