diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index 5824b0967e773..4780ea931ea50 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -109,8 +109,8 @@ class Ernie4_5_MoeMoE(nn.Module):
         layer_idx = extract_layer_index(prefix)
         self.layer_idx = layer_idx
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts",
-                                              None)
+        self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
+                                   > 0)
 
         if self.tp_size > config.moe_num_experts:
             raise ValueError(
@@ -137,7 +137,7 @@ class Ernie4_5_MoeMoE(nn.Module):
             prefix=f"{prefix}.experts",
             e_score_correction_bias=self.gate.e_score_correction_bias)
 
-        if self.moe_num_shared_experts is not None:
+        if self.has_shared_experts:
             intermediate_size = (config.moe_intermediate_size *
                                  config.moe_num_shared_experts)
             self.shared_experts = Ernie4_5_MoeMLP(
@@ -153,7 +153,8 @@ class Ernie4_5_MoeMoE(nn.Module):
         orig_shape = hidden_states.shape
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.moe_num_shared_experts is not None:
+        shared_output = None
+        if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)
 
         router_logits, _ = self.gate(hidden_states)
@@ -161,7 +162,7 @@ class Ernie4_5_MoeMoE(nn.Module):
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        if self.has_shared_experts and \
+        if self.has_shared_experts and \
                 shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
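
The core of the change is replacing the "is the config attribute present" test with a "is the shared-expert count positive" test. Below is a minimal standalone sketch (not part of the patch, and not vLLM code) showing why that matters when a config explicitly sets moe_num_shared_experts to 0; the helper function names and the SimpleNamespace configs are hypothetical, only the "moe_num_shared_experts" attribute comes from the diff.

# Minimal sketch contrasting the old and new predicates from the diff.
from types import SimpleNamespace

def wants_shared_experts_old(config):
    # Old logic: any non-None value counts, including 0,
    # which would lead to building a zero-width shared MLP.
    return getattr(config, "moe_num_shared_experts", None) is not None

def wants_shared_experts_new(config):
    # New logic: only a positive expert count enables the shared branch.
    return getattr(config, "moe_num_shared_experts", 0) > 0

cfg_zero = SimpleNamespace(moe_num_shared_experts=0)   # hypothetical config
cfg_missing = SimpleNamespace()                        # attribute absent

print(wants_shared_experts_old(cfg_zero))     # True  -> shared experts built
print(wants_shared_experts_new(cfg_zero))     # False -> shared experts skipped
print(wants_shared_experts_old(cfg_missing))  # False
print(wants_shared_experts_new(cfg_missing))  # False

The forward-path hunks follow the same idea: shared_output is now initialized to None before the conditional, so the final "shared_output is not None" guard is well defined even when the shared-expert branch never runs.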