mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 23:27:12 +08:00
[Refactor] Optimize select_experts (#28069)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
3aaa94ac99
commit
5031cd5d55
@ -1246,7 +1246,6 @@ def eplb_map_to_physical_and_record(
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
indices_type: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map the logical expert ids to physical expert ids
|
||||
@ -1260,7 +1259,6 @@ def eplb_map_to_physical_and_record(
|
||||
expert_load_view: The expert load view.
|
||||
logical_to_physical_map: The logical to physical map.
|
||||
logical_replica_count: The logical replica count.
|
||||
indices_type: The indices type.
|
||||
|
||||
Returns:
|
||||
The physical expert ids.
|
||||
@ -1310,9 +1308,6 @@ def eplb_map_to_physical_and_record(
|
||||
index=topk_ids_flatten.long(),
|
||||
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
|
||||
)
|
||||
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
return topk_ids
|
||||
|
||||
|
||||
|
||||
@ -68,7 +68,6 @@ else:
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
indices_type: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
# CPU fallback: no EPLB so just return as is
|
||||
return topk_ids
|
||||
@ -1509,8 +1508,6 @@ class FusedMoE(CustomOp):
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
elif e_score_correction_bias is not None:
|
||||
topk_weights, topk_ids = fused_topk_bias(
|
||||
hidden_states=hidden_states,
|
||||
@ -1519,7 +1516,7 @@ class FusedMoE(CustomOp):
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
if routed_scaling_factor is not None:
|
||||
if routed_scaling_factor != 1.0:
|
||||
topk_weights *= routed_scaling_factor
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
@ -1536,8 +1533,6 @@ class FusedMoE(CustomOp):
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
@ -1549,9 +1544,11 @@ class FusedMoE(CustomOp):
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
indices_type=indices_type,
|
||||
)
|
||||
|
||||
if (indices_type is not None) and topk_ids.dtype != indices_type:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
assert topk_ids.dtype == indices_type or indices_type is None
|
||||
|
||||
# Compute zero expert result if needed
|
||||
|
||||
@ -1706,7 +1706,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routed_scaling_factor=1.0,
|
||||
tile_tokens_dim=None,
|
||||
routing_method_type=routing_method_type,
|
||||
do_finalize=True,
|
||||
|
||||
@ -118,7 +118,7 @@ class FlashConfig(PretrainedConfig):
|
||||
router_dtype="float32",
|
||||
router_bias=False,
|
||||
topk_method=None,
|
||||
routed_scaling_factor=None,
|
||||
routed_scaling_factor=1.0,
|
||||
zero_expert_num=0,
|
||||
zero_expert_type=None,
|
||||
nextn_use_scmoe=False,
|
||||
|
||||
@ -625,7 +625,7 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", None)
|
||||
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.first_k_dense_replace = getattr(
|
||||
config, "first_k_dense_replace", self.num_hidden_layers
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user