Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 03:15:01 +08:00)
Fix routing_bias dtype (#25711)
Signed-off-by: Shu Wang. <shuw@nvidia.com>
parent 57329a8c01
commit 081b5594a2
@@ -1454,10 +1454,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
         if use_llama4_routing:
             routing_method_type = flashinfer.RoutingMethodType.Llama4
+        routing_bias = e_score_correction_bias
+        if routing_bias is not None:
+            routing_bias = routing_bias.to(torch.bfloat16)
         out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
             routing_logits=router_logits
             if use_llama4_routing else router_logits.to(torch.float32),
-            routing_bias=e_score_correction_bias,
+            routing_bias=routing_bias,
             hidden_states=hidden_states_fp4,
             hidden_states_scale=hidden_states_scale_linear_fp4.view(
                 torch.float8_e4m3fn).flatten(),
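For context, the following is a minimal sketch (plain PyTorch, standalone) of the dtype normalization this commit applies: the expert-score correction bias, when present, is cast to torch.bfloat16 before being passed as routing_bias to flashinfer.fused_moe.trtllm_fp4_block_scale_moe. The helper name prepare_routing_bias is hypothetical; in the actual change the cast is inlined in ModelOptNvFp4FusedMoE as shown in the diff above.

# Sketch only: illustrates the routing_bias dtype handling from the diff.
# `prepare_routing_bias` is a hypothetical helper, not part of vLLM.
from typing import Optional

import torch


def prepare_routing_bias(
        e_score_correction_bias: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    """Cast the expert-score correction bias to bfloat16, if present.

    The bias may be stored in another dtype (e.g. float32), while the
    fused MoE kernel call in the diff expects a bfloat16 routing bias.
    """
    if e_score_correction_bias is None:
        return None
    return e_score_correction_bias.to(torch.bfloat16)


if __name__ == "__main__":
    bias_fp32 = torch.randn(8, dtype=torch.float32)  # one entry per expert
    bias_bf16 = prepare_routing_bias(bias_fp32)
    assert bias_bf16.dtype == torch.bfloat16
    assert prepare_routing_bias(None) is None  # a missing bias stays None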