mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 13:16:32 +08:00
[Bugfix] Fix Ernie4_5_MoeForCausalLM shared experts (#21717)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
2cc571199b
commit
1b769dccf3
@ -109,8 +109,8 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.layer_idx = layer_idx
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts",
|
||||
None)
|
||||
self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
|
||||
> 0)
|
||||
|
||||
if self.tp_size > config.moe_num_experts:
|
||||
raise ValueError(
|
||||
@ -137,7 +137,7 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias)
|
||||
|
||||
if self.moe_num_shared_experts is not None:
|
||||
if self.has_shared_experts:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.moe_num_shared_experts)
|
||||
self.shared_experts = Ernie4_5_MoeMLP(
|
||||
@ -153,7 +153,8 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.moe_num_shared_experts is not None:
|
||||
shared_output = None
|
||||
if self.has_shared_experts:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
@ -161,7 +162,7 @@ class Ernie4_5_MoeMoE(nn.Module):
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
if self.moe_num_shared_experts is not None and \
|
||||
if self.has_shared_experts and \
|
||||
shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user