[Refactor] Optimize select_experts (#28069)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-11-19 18:53:15 -05:00 committed by GitHub
parent 3aaa94ac99
commit 5031cd5d55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 7 additions and 15 deletions

View File

@@ -1246,7 +1246,6 @@ def eplb_map_to_physical_and_record(
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: torch.dtype | None = None,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
@@ -1260,7 +1259,6 @@ def eplb_map_to_physical_and_record(
expert_load_view: The expert load view.
logical_to_physical_map: The logical to physical map.
logical_replica_count: The logical replica count.
indices_type: The indices type.
Returns:
The physical expert ids.
@@ -1310,9 +1308,6 @@ def eplb_map_to_physical_and_record(
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_ids

View File

@@ -68,7 +68,6 @@ else:
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: torch.dtype | None,
) -> torch.Tensor:
# CPU fallback: no EPLB so just return as is
return topk_ids
@@ -1509,8 +1508,6 @@ class FusedMoE(CustomOp):
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
elif e_score_correction_bias is not None:
topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
@@ -1519,7 +1516,7 @@ class FusedMoE(CustomOp):
topk=top_k,
renormalize=renormalize,
)
if routed_scaling_factor is not None:
if routed_scaling_factor != 1.0:
topk_weights *= routed_scaling_factor
elif custom_routing_function is None:
topk_weights, topk_ids, token_expert_indices = fused_topk(
@@ -1536,8 +1533,6 @@ class FusedMoE(CustomOp):
topk=top_k,
renormalize=renormalize,
)
if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
if enable_eplb:
assert expert_load_view is not None
@@ -1549,9 +1544,11 @@ class FusedMoE(CustomOp):
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
indices_type=indices_type,
)
if (indices_type is not None) and topk_ids.dtype != indices_type:
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
# Compute zero expert result if needed

View File

@@ -1706,7 +1706,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routed_scaling_factor=1.0,
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,

View File

@@ -118,7 +118,7 @@ class FlashConfig(PretrainedConfig):
router_dtype="float32",
router_bias=False,
topk_method=None,
routed_scaling_factor=None,
routed_scaling_factor=1.0,
zero_expert_num=0,
zero_expert_type=None,
nextn_use_scmoe=False,

View File

@@ -625,7 +625,7 @@ class OpenPanguDecoderLayer(nn.Module):
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", None)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.num_hidden_layers = config.num_hidden_layers
self.first_k_dense_replace = getattr(
config, "first_k_dense_replace", self.num_hidden_layers