[Bugfix] Fix Ernie4_5_MoeForCausalLM shared experts (#21717)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li, 2025-07-28 19:02:25 +08:00 (committed by GitHub)
parent 2cc571199b
commit 1b769dccf3

@@ -109,8 +109,8 @@ class Ernie4_5_MoeMoE(nn.Module):
         layer_idx = extract_layer_index(prefix)
         self.layer_idx = layer_idx
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts",
-                                              None)
+        self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
+                                   > 0)
 
         if self.tp_size > config.moe_num_experts:
             raise ValueError(
@@ -137,7 +137,7 @@ class Ernie4_5_MoeMoE(nn.Module):
             prefix=f"{prefix}.experts",
             e_score_correction_bias=self.gate.e_score_correction_bias)
 
-        if self.moe_num_shared_experts is not None:
+        if self.has_shared_experts:
             intermediate_size = (config.moe_intermediate_size *
                                  config.moe_num_shared_experts)
             self.shared_experts = Ernie4_5_MoeMLP(
@@ -153,7 +153,8 @@ class Ernie4_5_MoeMoE(nn.Module):
         orig_shape = hidden_states.shape
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.moe_num_shared_experts is not None:
+        shared_output = None
+        if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)
 
         router_logits, _ = self.gate(hidden_states)
@@ -161,7 +162,7 @@ class Ernie4_5_MoeMoE(nn.Module):
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        if self.moe_num_shared_experts is not None and \
+        if self.has_shared_experts and \
             shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
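
For context on the gate that changed: with the old check, a config that explicitly sets moe_num_shared_experts = 0 still passes `is not None`, so the shared-expert branch is taken even though there are no shared experts to build; the new check only enables the branch for a strictly positive count, and forward() now pre-initializes shared_output so the later `is not None` test is always well defined. Below is a minimal, standalone sketch of the two gates, assuming only plain Python; SimpleNamespace stands in for the model config and is not part of the vLLM code.

from types import SimpleNamespace

def has_shared_experts_old(config) -> bool:
    # Old gate: 0 is not None, so an explicit moe_num_shared_experts = 0
    # was still treated as "shared experts present".
    return getattr(config, "moe_num_shared_experts", None) is not None

def has_shared_experts_new(config) -> bool:
    # Fixed gate: only a strictly positive count enables the shared-expert path.
    return getattr(config, "moe_num_shared_experts", 0) > 0

for cfg in (SimpleNamespace(),                           # attribute absent
            SimpleNamespace(moe_num_shared_experts=0),   # explicitly zero
            SimpleNamespace(moe_num_shared_experts=2)):  # shared experts enabled
    print(has_shared_experts_old(cfg), has_shared_experts_new(cfg))
# prints: False False, then True False, then True True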